git.scottworley.com Git - pinch/blob - pinch.py
Move git_cache out to a separate library
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tarfile
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Callable,
23 Dict,
24 Iterable,
25 List,
26 Mapping,
27 NamedTuple,
28 NewType,
29 Optional,
30 Set,
31 Tuple,
32 Type,
33 TypeVar,
34 Union,
35 )
36
37 import git_cache
38
# Use xdg module when it's less painful to have as a dependency


class XDG(NamedTuple):
    """Minimal stand-in for the third-party ``xdg`` module.

    Only the one directory this program reads is provided.
    """
    # Base directory under which cache files are stored.
    XDG_CACHE_HOME: str


# The XDG Base Directory spec treats an unset *or empty* XDG_CACHE_HOME as
# "use ~/.cache".  A plain getenv() default only covers the unset case and
# would let an empty value through, producing cache paths relative to the
# current directory, so fall back with `or` instead.
xdg = XDG(
    XDG_CACHE_HOME=(os.environ.get('XDG_CACHE_HOME')
                    or os.path.expanduser('~/.cache')))
50
51
class VerificationError(Exception):
    """Raised when a verification step reports failure."""
54
55
class Verification:
    """Reports the progress and outcome of checks on stderr.

    Descriptions are printed left to right on one line; the outcome is then
    padded towards the right edge of the terminal and the line is ended.
    """

    def __init__(self) -> None:
        # Number of columns already written on the current output line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print ``s`` and a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length = self.line_length + len(s) + 1  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap ``s`` in the terminal escape sequences for code ``c``."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print the outcome of the current check and end the line.

        Raises:
            VerificationError: if ``r`` is False.
        """
        outcomes = {True: ('OK ', 92), False: ('FAIL', 91)}
        text, color_code = outcomes[r]
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(text)
        # The modulo keeps the padding non-negative even when the line has
        # already run past the terminal width.
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(text, color_code), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print description ``s`` and immediately report outcome ``r``."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current check."""
        self.result(True)
85
86
# Distinct static types for the two textual digest encodings used in this
# file, so a type checker can catch mixing them up.  Both are plain str at
# runtime.
Digest16 = NewType('Digest16', str)  # base16 (hexadecimal) form
Digest32 = NewType('Digest32', str)  # base32 form
89
90
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel page.

    This is a mutable namespace: ``parse_channel`` creates entries with
    ``url``, ``digest`` and ``size``; ``fetch_resources`` later assigns
    ``absolute_url`` and ``file``.
    """
    # URL of the resource after joining ``url`` onto the channel page's URL.
    absolute_url: str
    # Digest listed for the resource in the table.
    digest: Digest16
    # Local path of the fetched resource.
    file: str
    # Size listed for the resource in the table.
    size: int
    # The row's link target, exactly as it appears in the page.
    url: str
97
98
class AliasPin(NamedTuple):
    """Pin for an alias search path; it carries no fields."""
101
102
class SymlinkPin(NamedTuple):
    """Pin for a symlink search path; it has no stored fields."""

    @property
    def release_name(self) -> str:
        """The fixed release name shared by every symlink pin."""
        return 'link'
107
108
class GitPin(NamedTuple):
    """Pin identifying one revision of a git repository."""
    # The pinned git revision.
    git_revision: str
    # Name under which this revision is released/installed.
    release_name: str
112
113
class ChannelPin(NamedTuple):
    """Pin identifying one release of a channel."""
    # Git revision the channel release corresponds to.
    git_revision: str
    # Name of the channel release.
    release_name: str
    # Absolute URL of the release's nixexprs.tar.xz tarball.
    tarball_url: str
    # Digest recorded for that tarball.
    tarball_sha256: str
119
120
# Any one of the pin record types defined in this file.
Pin = Union[
    AliasPin,
    SymlinkPin,
    GitPin,
    ChannelPin,
]
122
123
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add ``filename`` to the Nix store with ``nix-store --add``.

    Returns the command's standard output, decoded and stripped of
    surrounding whitespace.  A non-zero exit status is reported through
    ``v.result``.
    """
    v.status('Putting tarball in Nix store')
    argv = ['nix-store', '--add', filename]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    output = completed.stdout.decode()
    return output.strip()  # type: ignore # (for old mypy)
130
131
def symlink_archive(v: Verification, path: str) -> str:
    """Build a tarball holding one symlink to ``path`` and store it.

    The archive contains a single member named ``link``.  It is created in
    a temporary directory, handed to ``copy_to_nix_store``, and that
    function's return value is returned.
    """
    with tempfile.TemporaryDirectory() as workdir:
        link_path = os.path.join(workdir, 'link')
        tarball_path = os.path.join(workdir, 'link.tar.gz')
        os.symlink(path, link_path)
        with tarfile.open(tarball_path, mode='x:gz') as archive:
            archive.add(link_path, arcname='link')
        # Copy while the temporary directory still exists.
        return copy_to_nix_store(v, tarball_path)
139
140
class AliasSearchPath(NamedTuple):
    """Search path that is another name for an existing one."""
    # Name of the search path this one aliases.
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Return a fresh (field-less) alias pin; both arguments are unused."""
        return AliasPin()
147
148
class SymlinkSearchPath(NamedTuple):
    """Search path that is a symlink to a filesystem path."""
    # Target the generated symlink points at.
    path: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Return a fresh (field-less) symlink pin; both arguments are unused."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return the result of archiving a symlink to ``self.path``."""
        target = self.path
        return symlink_archive(v, target)
158
159
class GitSearchPath(NamedTuple):
    """Search path that follows a ref of a git repository."""
    # Ref to follow within the repository.
    git_ref: str
    # Location of the repository.
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Return a pin for the revision ``git_cache.fetch`` yields.

        When ``old_pin`` is given, the old and new revisions are first
        checked with ``verify_git_ancestry``.
        """
        _, fetched_revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(
                v, self, old_pin.git_revision, fetched_revision)
        name = git_revision_name(v, self, fetched_revision)
        return GitPin(git_revision=fetched_revision, release_name=name)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return the tarball for the pinned revision.

        The revision is first made available through ``git_cache``.
        """
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, pin.git_revision)
        tarball = git_get_tarball(v, self, pin)
        return tarball
177
178
class ChannelSearchPath(NamedTuple):
    """Search path that follows a channel page backed by a git repository."""
    # URL of the channel page.
    channel_url: str
    # Ref of the repository that the channel's revisions are looked up on.
    git_ref: str
    # Location of the repository.
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Return a pin for the release the channel currently advertises.

        If the advertised git revision equals the one in ``old_pin``, the
        old pin is returned unchanged.  Otherwise the channel's resources
        are fetched and checked against the git repository before a new pin
        is built.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        channel_html, forwarded_url = fetch_channel(v, self)
        table, new_gitpin = parse_channel(v, channel_html)
        new_revision = new_gitpin.git_revision
        if old_pin is not None and old_pin.git_revision == new_revision:
            return old_pin

        fetch_resources(v, new_gitpin, forwarded_url, table)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, new_revision)
        if old_pin is not None:
            verify_git_ancestry(v, self, old_pin.git_revision, new_revision)
        check_channel_contents(v, self, table, new_gitpin)

        nixexprs = table['nixexprs.tar.xz']
        return ChannelPin(
            git_revision=new_revision,
            release_name=new_gitpin.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Fetch the pinned tarball with ``fetch_with_nix_prefetch_url``."""
        assert isinstance(pin, ChannelPin)

        expected_digest = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected_digest)
211
212
# Any one of the search path types defined in this file.
SearchPath = Union[
    AliasSearchPath,
    SymlinkSearchPath,
    GitSearchPath,
    ChannelSearchPath,
]
# The search path types that have a git_repo to produce tarballs from.
TarrableSearchPath = Union[
    GitSearchPath,
    ChannelSearchPath,
]
218
219
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the contents of directory trees ``a`` and ``b``.

    Every file, and every symlink to a directory, found under either tree
    is compared by content, except for paths starting with ``.git/``.

    Returns:
        The ``(match, mismatch, errors)`` lists produced by
        ``filecmp.cmpfiles``, holding paths relative to the tree roots.

    Raises:
        OSError: if walking either tree fails.
    """

    def reraise(error: OSError) -> None:
        # os.walk ignores errors unless its onerror callback raises them.
        raise error

    def relative(parent: str, name: str) -> str:
        if parent == '.':
            return name
        return os.path.join(parent, name)

    def tree_entries(root: str) -> List[str]:
        entries: List[str] = []
        for dirpath, dirnames, filenames in os.walk(root, onerror=reraise):
            rel_dir = os.path.relpath(dirpath, start=root)
            entries += [relative(rel_dir, name) for name in filenames]
            # os.walk lists symlinks to directories among the directories
            # (and does not descend into them); count them as files.
            entries += [relative(rel_dir, name) for name in dirnames
                        if os.path.islink(relative(dirpath, name))]
        return entries

    candidates: Set[str] = set()
    for root in (a, b):
        candidates.update(
            entry for entry in tree_entries(root)
            if not entry.startswith('.git/'))
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
246
247
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel page.

    Returns:
        A ``(channel_html, forwarded_url)`` pair: the decoded response body
        and the URL the response was finally retrieved from.

    Raises:
        VerificationError: (via ``v``) if the status is not 200 or the
            final URL equals ``channel.channel_url``.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # deterministically instead of whenever the object is collected.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel_html = request.read().decode()
        forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
257
258
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract the file table and the release's git pin from a channel page.

    The page is parsed as XML.  The release name is the third
    whitespace-separated word of the ``title`` text and must equal the
    third word of the first ``h1``.  The git revision is the text of the
    first ``tt`` element, which must directly follow the text
    ``'Git commit '``.  Every ``tr`` after the first becomes a table entry
    keyed by the text of the row's link.

    Returns:
        A ``(table, pin)`` pair.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        text = doc.getElementsByTagName(tag)[0].firstChild.nodeValue
        return text.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    git_revision = commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    rows = doc.getElementsByTagName('tr')
    for row in rows[1:]:  # the first row is skipped
        link = row.childNodes[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=title_name, git_revision=git_revision)
290
291
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 digest of ``s`` as a hexadecimal string."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
294
295
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 digest of a file's contents as hexadecimal.

    The file is read in 4096-byte blocks rather than all at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if block == b'':
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
303
304
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to base16 with ``nix to-base16``.

    Returns the command's standard output, decoded and stripped.  A
    non-zero exit status is reported through ``v.result``.
    """
    v.status('Converting digest to base16')
    argv = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
311
312
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 SHA-256 digest to base32 with ``nix to-base32``.

    Returns the command's standard output, decoded and stripped.  A
    non-zero exit status is reported through ``v.result``.
    """
    v.status('Converting digest to base32')
    argv = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
319
320
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch ``url`` with ``nix-prefetch-url`` and double-check its digest.

    The command is expected to print exactly two newline-terminated lines,
    a digest and a path.  The printed digest (converted to base16) and a
    locally computed digest of the file at that path must both equal
    ``digest``.

    Returns:
        The path printed by ``nix-prefetch-url``.
    """
    v.status('Fetching %s' % url)
    argv = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)

    prefetch_digest, path, empty = completed.stdout.decode().split('\n')
    assert empty == ''

    # The conversion reports its own status line before this check's
    # description is printed.
    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)

    v.status("Verifying file digest")
    actual_digest = digest_file(path)
    v.result(actual_digest == digest)
    return path  # type: ignore # (for old mypy)
337
338
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Fetch the channel's ``git-revision`` and ``nixexprs.tar.xz`` files.

    For each of the two resources, the table entry gains ``absolute_url``
    (its ``url`` joined onto ``forwarded_url``) and ``file`` (the fetched
    path).  The fetched ``git-revision`` file must then contain
    ``pin.git_revision``.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the file handle is closed promptly.
    with open(table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == pin.git_revision)
351
352
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the cache-file path for the tarball of ``pin``.

    The file lives under ``pinch/git-tarball`` in the XDG cache directory
    and is named after the digest of the repository location, the git
    revision and the release name.
    """
    repo_digest = digest_string(channel.git_repo.encode())
    leaf = '%s-%s-%s' % (repo_digest, pin.git_revision, pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', leaf)
361
362
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Check the ancestry relation between two revisions with git.

    Runs ``git merge-base --is-ancestor old_revision new_revision`` in the
    repository's cache directory and reports whether it exited with 0.
    """
    cachedir = git_cache.git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of previous rev %s' % old_revision)
    # NOTE: with this argument order git tests whether old_revision is an
    # ancestor of new_revision.
    argv = [
        'git', '-C', cachedir,
        'merge-base', '--is-ancestor',
        old_revision, new_revision,
    ]
    completed = subprocess.run(argv)
    v.result(completed.returncode == 0)
378
379
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the unpacked channel tarball matches the git checkout.

    The ``pin.release_name`` directory inside ``channel_contents`` is
    compared with ``git_contents``.  At least one file must match, none may
    differ, and the incomparable files must be exactly the expected set
    listed below.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Paths that are expected to show up as incomparable.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision',
    ]
    benign_errors = [name for name in expected_errors if name in errors]
    unexpected_errors = [
        name for name in errors if name not in benign_errors]

    v.check(
        '%d unexpected incomparable files' % len(unexpected_errors),
        len(unexpected_errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
411
412
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the fetched ``nixexprs.tar.xz`` tarball into ``dest``."""
    tarball = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    # Look the entry up again for the unpack itself, exactly as before.
    shutil.unpack_archive(table['nixexprs.tar.xz'].file, dest)
    v.ok()
420
421
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Unpack the tree of ``pin.git_revision`` into ``dest``.

    ``git archive`` is piped straight into ``tar x``; both commands must
    exit with status 0.
    """
    v.status('Checking out corresponding git revision')
    archive_argv = [
        'git',
        '-C',
        git_cache.git_cachedir(channel.git_repo),
        'archive',
        pin.git_revision,
    ]
    untar_argv = ['tar', 'x', '-C', dest, '-f', '-']

    git = subprocess.Popen(archive_argv, stdout=subprocess.PIPE)
    tar = subprocess.Popen(untar_argv, stdin=git.stdout)
    # Drop this process's copy of the pipe's read end; tar keeps its own.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    both_succeeded = git.returncode == 0 and tar.returncode == 0
    v.result(both_succeeded)
441
442
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a tarball of ``pin.git_revision``, building it if needed.

    A cache file (see ``tarball_cache_file``) records the path returned for
    an earlier build; if that path still exists it is returned directly.
    Otherwise ``git archive`` is piped through ``xz`` into a temporary
    file, the file is passed to ``copy_to_nix_store``, and the resulting
    path is recorded in the cache file and returned.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        # Read via a context manager so the handle is closed promptly.
        with open(cache_file) as cached:
            cached_tarball = cached.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # Binary mode: only the file descriptor is handed to xz.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cache.git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe's read end (as git_checkout does)
            # so git gets SIGPIPE instead of blocking forever if xz exits
            # early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        # Must happen before the temporary directory is removed.
        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
477
478
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    Within the ``pin.release_name`` directory of ``channel_contents``,
    ``.git-revision`` must contain ``pin.git_revision`` and the content of
    ``.version-suffix`` must be a suffix of ``pin.release_name``.
    """

    def read_metadata(name: str) -> str:
        # Read through a context manager so no file handle is leaked.
        path = os.path.join(channel_contents, pin.release_name, name)
        with open(path) as metadata_file:
            return metadata_file.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
501
502
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify the channel tarball against its metadata and git revision.

    The tarball and the git revision are unpacked into two temporary
    directories, which are compared and then removed.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, table, channel_contents)
            check_channel_metadata(v, pin, channel_contents)

            git_checkout(v, channel, pin, git_contents)

            compare_tarball_and_git(v, pin, channel_contents, git_contents)

            v.status('Removing temporary directories')
    # Reported only once both directories have actually been removed.
    v.ok()
520
521
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name for ``git_revision``.

    The name is the basename of the repository location, a dash, and the
    output of ``git log -n1 --format=%ct-%h --abbrev=11`` for the revision.
    The command must exit with 0 and print something.
    """
    v.status('Getting commit date')
    argv = [
        'git',
        '-C',
        git_cache.git_cachedir(channel.git_repo),
        'log',
        '-n1',
        '--format=%ct-%h',
        '--abbrev=11',
        '--no-show-signature',
        git_revision,
    ]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0 and completed.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    stamp = completed.stdout.decode().strip()
    return '%s-%s' % (repo_name, stamp)
540
541
# Key and value types of the dictionaries handled by the helpers below.
K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split ``d`` in two according to ``pred``.

    Returns:
        A ``(selected, remaining)`` pair of new dictionaries: the items for
        which ``pred(key, value)`` is true, and all the others.  Each keeps
        the iteration order of ``d``.
    """
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for key, value in d.items():
        target = selected if pred(key, value) else remaining
        target[key] = value
    return selected, remaining
556
557
def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split ``d`` into the items whose key is in ``fields`` and the rest.

    Returns:
        A ``(selected, remaining)`` pair of new dictionaries.
    """

    def key_is_listed(key: K, _value: V) -> bool:
        return key in fields

    return partition_dict(key_is_listed, d)
561
562
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path, and the pin if present, from one section.

    The section's ``type`` key selects the search path and pin classes.
    The remaining keys named after the pin class's fields form the pin;
    all others are passed to the search path class.

    Returns:
        A ``(search_path, pin)`` pair.  ``pin`` is None when the pin class
        has fields and the section sets none of them.

    Raises:
        KeyError: if ``type`` is missing or names an unknown kind.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    search_path_type, pin_type = mapping[conf['type']]
    _, all_fields = filter_dict(dict(conf.items()), {'type'})
    pin_fields, remaining_fields = filter_dict(
        all_fields, set(pin_type._fields))
    # A pin class without fields is always considered present.
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin_present = bool(pin_fields) or not pin_type._fields
    if pin_present:
        pin = pin_type(**pin_fields)  # type: ignore
    else:
        pin = None
    return search_path_type(**remaining_fields), pin
578
579
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like ``read_config_section``, but insist that a pin is present.

    Raises:
        Exception: if the section has no pin; ``section`` is named in the
            message.
    """
    sp, pin = read_config_section(conf)
    if pin is not None:
        return sp, pin
    raise Exception(
        'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
        section)
588
589
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the configuration file ``filename``.

    Raises:
        OSError: if the file cannot be opened.
        configparser.Error: if the file is not valid configuration.
    """
    config = configparser.ConfigParser()
    # Open via a context manager so the handle is closed even when parsing
    # fails; previously it was left for the garbage collector.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
594
595
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several configuration files and merge their sections.

    Returns:
        A dictionary from section name to section, in the order read.

    Raises:
        Exception: if a section name occurs more than once.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
606
607
def pinCommand(args: argparse.Namespace) -> None:
    """Implement the ``pin`` subcommand.

    Pins every section of ``args.channels_file``, or only those listed in
    ``args.channels`` when that list is non-empty, and writes the updated
    configuration back to the same file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = [
        name for name in config.sections()
        if not args.channels or name in args.channels]

    for section in selected:
        sp, old_pin = read_config_section(config[section])
        new_pin = sp.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
621
622
def updateCommand(args: argparse.Namespace) -> None:
    """Implement the ``update`` subcommand.

    Fetches the tarball of every pinned, non-alias section found in the
    ``args.channels_file`` files, then builds a single ``nix-env`` command
    installing all of them (aliases reuse their target's expression) into
    ``args.profile``.  With ``args.dry_run`` the command is printed rather
    than run.
    """
    v = Verification()

    config: Dict[str, Tuple[SearchPath, Pin]] = {}
    for section, conf in read_config_files(args.channels_file).items():
        config[section] = read_pinned_config_section(section, conf)

    def is_alias(_name: str, entry: Tuple[SearchPath, Pin]) -> bool:
        return isinstance(entry[0], AliasSearchPath)

    alias, nonalias = partition_dict(is_alias, config)

    # Each expression keeps one '%s' placeholder for the channel name,
    # filled in when the command line is assembled below.
    exprs: Dict[str, str] = {}
    for section, (sp, pin) in nonalias.items():
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (pin.release_name, tarball))

    for section, (sp, _) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        exprs[section] = exprs[sp.alias_of]

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression',
    ]
    command.extend(exprs[name] % name for name in sorted(exprs))

    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
        return

    v.status('Installing channels with nix-env')
    completed = subprocess.run(command)
    v.result(completed.returncode == 0)
660
661
def main() -> None:
    """Parse the command line and run the selected subcommand.

    Two subcommands exist: ``pin`` (a channels file followed by optional
    channel names) and ``update`` (one or more channels files, with the
    ``--dry-run`` and ``--profile`` options).
    """
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = (
        '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser())
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    # Each subparser stored its handler under ``func``.
    args.func(args)


if __name__ == '__main__':
    main()