from abc import ABC, abstractmethod
import argparse
import configparser
import filecmp
import functools
import getpass
import hashlib
import operator
import os
import os.path
import shlex
import shutil
import subprocess
import sys
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom

from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    NewType,
    Optional,
    Tuple,
    Type,
    Union,
)

# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    XDG_CACHE_HOME: str


xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))


class VerificationError(Exception):
    pass


class Verification:
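    """Print paired status/result lines to stderr.

    status() records what is being checked; result() right-pads the line to
    the terminal width and prints a green OK or red FAIL, raising
    VerificationError on failure.
    """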

    def __init__(self) -> None:
        self.line_length = 0

    def status(self, s: str) -> None:
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        cols = shutil.get_terminal_size().columns or 80
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        self.result(True)


Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
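# Digest16 is a sha256 digest in hex; Digest32 is the same digest in Nix's
# base-32 encoding (see to_Digest16 / to_Digest32 below, which shell out to
# `nix to-base16` / `nix to-base32`).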


class ChannelTableEntry(types.SimpleNamespace):
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


class AliasPin(NamedTuple):
    pass


class GitPin(NamedTuple):
    git_revision: str
    release_name: str


class ChannelPin(NamedTuple):
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str


Pin = Union[AliasPin, GitPin, ChannelPin]


class SearchPath(types.SimpleNamespace, ABC):
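    """One [section] of the channels file.

    read_search_path() instantiates the subclass named by the section's
    'type' key, passing the section's keys through as attributes; pin()
    captures the current upstream state as a Pin.
    """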

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        pass


class AliasSearchPath(SearchPath):
    alias_of: str

    def pin(self, v: Verification) -> AliasPin:
        assert not hasattr(self, 'git_repo')
        return AliasPin()


# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    table: Dict[str, ChannelTableEntry]


class GitSearchPath(TarrableSearchPath):
    def pin(self, v: Verification) -> GitPin:
        old_revision = (
            self.git_revision if hasattr(self, 'git_revision') else None)
        if hasattr(self, 'git_revision'):
            del self.git_revision

        new_revision = git_fetch(v, self, None, old_revision)
        return GitPin(release_name=git_revision_name(v, self, new_revision),
                      git_revision=new_revision)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        if 'git_revision' not in conf or 'release_name' not in conf:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)
        the_pin = GitPin(
            release_name=conf['release_name'],
            git_revision=conf['git_revision'])

        ensure_git_rev_available(v, self, the_pin, None)
        return git_get_tarball(v, self, the_pin)


class ChannelSearchPath(TarrableSearchPath):
    def pin(self, v: Verification) -> ChannelPin:
        old_revision = (
            self.git_revision if hasattr(self, 'git_revision') else None)
        if hasattr(self, 'git_revision'):
            del self.git_revision

        fetch(v, self)
        new_gitpin = parse_channel(v, self)
        fetch_resources(v, self, new_gitpin)
        ensure_git_rev_available(v, self, new_gitpin, old_revision)
        check_channel_contents(v, self)
        return ChannelPin(
            release_name=self.release_name,
            tarball_url=self.table['nixexprs.tar.xz'].absolute_url,
            tarball_sha256=self.table['nixexprs.tar.xz'].digest,
            git_revision=self.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        if 'git_repo' not in conf or 'release_name' not in conf:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(
                conf['tarball_sha256']))


def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
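    """Recursively compare two directory trees, ignoring .git/.

    Returns the (match, mismatch, errors) lists from filecmp.cmpfiles over
    the union of both trees' files (symlinked directories included as
    files).
    """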

    def throw(error: OSError) -> None:
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    return filecmp.cmpfiles(a, b, files, shallow=False)


def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    v.status('Fetching channel')
    request = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel.channel_html = request.read()
    channel.forwarded_url = request.geturl()
    v.result(request.status == 200)  # type: ignore # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)


def parse_channel(v: Verification, channel: TarrableSearchPath) -> GitPin:
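    """Extract the release name, git revision, and file table from the
    channel's HTML page, storing them on the channel object."""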
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
    return GitPin(release_name=title_name, git_revision=channel.git_revision)


def digest_string(s: bytes) -> Digest16:
    return Digest16(hashlib.sha256(s).hexdigest())


def digest_file(filename: str) -> Digest16:
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # pylint: disable=cell-var-from-loop
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return Digest16(hasher.hexdigest())


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    v.status('Converting digest to base16')
    process = subprocess.run(
        ['nix', 'to-base16', '--type', 'sha256', digest32], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    v.status('Converting digest to base32')
    process = subprocess.run(
        ['nix', 'to-base32', '--type', 'sha256', digest16], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())


def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
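    """Fetch url into the Nix store with nix-prefetch-url and return the
    resulting store path, checking the digest both as reported by
    nix-prefetch-url and by re-hashing the file."""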
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore # (for old mypy)


def fetch_resources(
        v: Verification,
        channel: ChannelSearchPath,
        pin: GitPin) -> None:
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    v.result(
        open(
            channel.table['git-revision'].file).read(999) == pin.git_revision)


def git_cachedir(git_repo: str) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git',
        digest_string(git_repo.encode()))


def tarball_cache_file(channel: TarrableSearchPath) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git-tarball',
        '%s-%s-%s' %
        (digest_string(channel.git_repo.encode()),
         channel.git_revision,
         channel.release_name))


def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
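    """Check that new_revision is an ancestor of the configured git_ref and,
    when old_revision is given, that old_revision is an ancestor of
    new_revision (i.e. history has only moved forward)."""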
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              new_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if old_revision is not None:
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            old_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  old_revision,
                                  new_revision])
        v.result(process.returncode == 0)


def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    new_revision = open(
        os.path.join(
            cachedir,
            'refs',
            'heads',
            channel.git_ref)).read(999).strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision


def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
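    """Make sure pin.git_revision exists in the local git cache, fetching
    only when it is missing; ancestry against old_revision is verified
    either way."""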
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        v.result(process.returncode == 0 or process.returncode == 1)
        if process.returncode == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)


def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))


def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    v.status('Extracting tarball %s' %
             channel.table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(
        channel.table['nixexprs.tar.xz'].file,
        dest)
    v.ok()


def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)


def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
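    """Build (or reuse) a Nix-store tarball for the pinned git revision.

    The store path is memoized in a small cache file; when that path still
    exists it is returned directly, otherwise the tarball is regenerated
    with `git archive | xz` and re-added via `nix-store --add`."""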
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        cached_tarball = open(cache_file).read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        open(cache_file, 'w').write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)


def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    v.status('Verifying git commit in channel tarball')
    v.result(
        open(
            os.path.join(
                channel_contents,
                channel.release_name,
                '.git-revision')).read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = open(
        os.path.join(
            channel_contents,
            channel.release_name,
            '.version-suffix')).read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))


def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
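    """Cross-check the channel tarball against a git checkout of the same
    revision: embedded metadata must match and file contents must be
    identical."""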
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()


def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
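    """Construct a release-name-like string for a git revision:
    <repo basename>-<commit timestamp>-<abbreviated hash>."""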
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              git_revision],
                             stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())


def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
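    """Build a SearchPath from one channels-file section.

    The section's 'type' key selects the class; the remaining keys become
    attributes.  A sketch of the expected layout (section names and URLs
    here are hypothetical examples, not taken from this repository):

        [nixpkgs]
        type = channel
        channel_url = https://example.com/channels/nixos-unstable

        [mytools]
        type = git
        git_repo = https://example.com/mytools.git
        git_ref = master

        [default]
        type = alias
        alias_of = nixpkgs
    """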
    mapping: Mapping[str, Type[SearchPath]] = {
        'alias': AliasSearchPath,
        'channel': ChannelSearchPath,
        'git': GitSearchPath,
    }
    return mapping[conf['type']](**dict(conf.items()))


def read_config(filename: str) -> configparser.ConfigParser:
    config = configparser.ConfigParser()
    config.read_file(open(filename), filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception('Duplicate channel "%s"' % section)
            merged_config[section] = config[section]
    return merged_config


def pinCommand(args: argparse.Namespace) -> None:
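    """`pinch pin`: resolve each requested section to a Pin and write the
    resulting keys back into the channels file."""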
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        sp = read_search_path(config[section])

        config[section].update(sp.pin(v)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)


def updateCommand(args: argparse.Namespace) -> None:
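    """`pinch update`: fetch each pinned channel's tarball into the Nix
    store and install all channels into the per-user channels profile with
    nix-env (or just print the command under --dry-run)."""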
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp = read_search_path(config[section])
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in config[section]
            continue
        tarball = sp.fetch(v, section, config[section])
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)


def main() -> None:
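    """Command-line entry point.

    Usage sketch (file names here are illustrative):

        pinch pin channels.ini [SECTION ...]
        pinch update [--dry-run] channels.ini [MORE_CHANNELS_FILES ...]
    """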
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pinCommand)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=updateCommand)
    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()