]> git.scottworley.com Git - pinch/blob - pinch.py
f4fd0b5e9ed0739d2e323da9fad8580473c46ddf
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 Mapping,
26 NamedTuple,
27 NewType,
28 Optional,
29 Tuple,
30 Type,
31 Union,
32 )
33
34 # Use xdg module when it's less painful to have as a dependency
35
36
class XDG(types.SimpleNamespace):
    """Namespace holding the XDG base-directory settings this script uses."""
    # Root directory under which this script keeps its cache data.
    XDG_CACHE_HOME: str
39
40
# Module-wide XDG settings: honor $XDG_CACHE_HOME when it is set, otherwise
# fall back to ~/.cache in the current user's home directory.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))
45
46
class VerificationError(Exception):
    """Raised by Verification.result() when a verification step fails."""
    pass
49
50
class Verification:
    """Progress reporter that writes to stderr and aborts on the first failure.

    Status fragments accumulate on a single line; the verdict is then
    right-aligned against the terminal width and the line is finished.
    """

    # Verdict text and ANSI color code, keyed by the boolean outcome.
    _VERDICTS = {True: ('OK ', 92), False: ('FAIL', 91)}

    def __init__(self) -> None:
        # Columns already consumed on the current status line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Write `s` plus a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # One extra column for the trailing space.
        self.line_length = self.line_length + len(s) + 1  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap `s` in the ANSI escape sequence for color code `c`."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with a verdict.

        Raises:
            VerificationError: if `r` is False.
        """
        text, color_code = self._VERDICTS[r]
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(text)
        # Modulo keeps the padding non-negative when the line has wrapped.
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(text, color_code), file=sys.stderr)
        self.line_length = 0
        if r:
            return
        raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Report status `s` and immediately conclude it with outcome `r`."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Conclude the current line as successful."""
        self.result(True)
80
81
# Distinct string types for the two textual digest encodings handled here:
# Digest16 is a base16 (hex) digest, as produced by hashlib's hexdigest();
# Digest32 is the base32 form produced by `nix to-base32`.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
84
85
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table parsed from a channel page."""
    # Fully resolved download URL; filled in by fetch_resources() for the
    # resources it downloads.
    absolute_url: str
    # Digest listed for the file in the channel table.
    digest: Digest16
    # Local path returned by fetch_with_nix_prefetch_url(); filled in by
    # fetch_resources() for the resources it downloads.
    file: str
    # Size listed for the file in the channel table.
    size: int
    # Link target (href) exactly as it appears in the channel table.
    url: str
92
93
class AliasPin(NamedTuple):
    """Pin for an alias entry; it carries no fields of its own."""
    pass
96
97
class GitPin(NamedTuple):
    """Pin identifying one git revision and the release name given to it."""
    # Commit the entry is pinned to.
    git_revision: str
    # Name used for the release (also used to name the generated tarball).
    release_name: str
101
102
class ChannelPin(NamedTuple):
    """Pin for a channel: a git revision plus the channel's tarball."""
    # Commit the channel release corresponds to.
    git_revision: str
    # Release name extracted from the channel page.
    release_name: str
    # Absolute URL of the channel's nixexprs.tar.xz.
    tarball_url: str
    # Digest of that tarball as listed in the channel table.
    tarball_sha256: str
108
109
# Any of the pin records that a SearchPath.pin() implementation may return.
Pin = Union[AliasPin, GitPin, ChannelPin]
111
112
class SearchPath(types.SimpleNamespace, ABC):
    """Abstract base for one configured search-path entry.

    Instances are built by read_search_path() from a config section, so each
    key of that section becomes an attribute of the instance.
    """

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        """Return the pin for this entry, reporting progress through `v`."""
        pass
118
119
class AliasSearchPath(SearchPath):
    """Search-path entry that reuses another section under a new name."""
    # Name of the config section this entry is an alias of.
    alias_of: str

    def pin(self, v: Verification) -> AliasPin:
        """Return an empty pin; an alias has nothing of its own to pin."""
        # An alias must not carry git settings of its own.
        assert not hasattr(self, 'git_repo')
        return AliasPin()
126
127
128 # (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
129 # which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """Base for search-path entries that are backed by a git repository."""
    # Raw body of the channel page; set by fetch().
    channel_html: bytes
    # URL of the channel page to download.
    channel_url: str
    # URL reported by the response after fetching channel_url; set by fetch().
    forwarded_url: str
    # Git ref that pinned revisions are verified against.
    git_ref: str
    # Location of the git repository to fetch from.
    git_repo: str
    # File table parsed from the channel page; set by parse_channel().
    table: Dict[str, ChannelTableEntry]
137
138
class GitSearchPath(TarrableSearchPath):
    """Search-path entry pinned to, and served from, a git ref."""

    def pin(self, v: Verification) -> GitPin:
        """Fetch the configured ref and return a pin for its current tip.

        A `git_revision` attribute already present on this object is removed
        and passed to git_fetch() as the previous revision.
        """
        previous_revision = None
        if hasattr(self, 'git_revision'):
            previous_revision = self.git_revision
            del self.git_revision

        tip = git_fetch(v, self, None, previous_revision)
        name = git_revision_name(v, self, tip)
        return GitPin(release_name=name, git_revision=tip)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the path of a tarball for the revision pinned in `conf`.

        Raises:
            Exception: if `conf` lacks `git_revision` or `release_name`.
        """
        is_pinned = 'git_revision' in conf and 'release_name' in conf
        if not is_pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        release_name = conf['release_name']
        git_revision = conf['git_revision']
        the_pin = GitPin(release_name=release_name, git_revision=git_revision)

        ensure_git_rev_available(v, self, the_pin, None)
        return git_get_tarball(v, self, the_pin)
162
163
class ChannelSearchPath(TarrableSearchPath):
    """Search-path entry pinned from a channel page and its tarball."""

    def pin(self, v: Verification) -> ChannelPin:
        """Download and verify the channel, returning a pin for it.

        A `git_revision` attribute already present on this object is removed
        and used as the previous revision for the git ancestry check.
        """
        old_revision = (
            self.git_revision if hasattr(self, 'git_revision') else None)
        if hasattr(self, 'git_revision'):
            del self.git_revision

        fetch(v, self)
        new_gitpin = parse_channel(v, self)
        fetch_resources(v, self, new_gitpin)
        ensure_git_rev_available(v, self, new_gitpin, old_revision)
        check_channel_contents(v, self, new_gitpin)
        tarball = self.table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=tarball.absolute_url,
            tarball_sha256=tarball.digest,
            git_revision=new_gitpin.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Fetch the pinned channel tarball and return its local path.

        Raises:
            Exception: if `conf` is missing any of the keys needed here.
        """
        # Check the keys that are actually read below as well, so that a
        # partially pinned section yields the friendly message rather than a
        # bare KeyError.
        required = (
            'git_repo',
            'release_name',
            'tarball_url',
            'tarball_sha256')
        if any(key not in conf for key in required):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(
                conf['tarball_sha256']))
194
195
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the files of two directory trees by content.

    Every regular file and every symlinked directory found under either tree
    is considered, except paths starting with '.git/'. Returns the
    (match, mismatch, errors) lists from filecmp.cmpfiles(), holding paths
    relative to the tree roots.
    """

    def reraise(err: OSError) -> None:
        # os.walk() swallows errors unless the handler raises them.
        raise err

    def rel_join(base: str, name: str) -> str:
        # Avoid a leading './' on entries directly under the root.
        if base == '.':
            return name
        return os.path.join(base, name)

    def tracked_files(root: str) -> set:
        found: List[str] = []
        for current, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(current, start=root)
            for filename in filenames:
                found.append(rel_join(relative, filename))
            # Symlinks to directories are listed among the dirs but are not
            # descended into, so record them as entries of their own.
            for subdir in subdirs:
                if os.path.islink(rel_join(current, subdir)):
                    found.append(rel_join(relative, subdir))
        return {entry for entry in found if not entry.startswith('.git/')}

    files_in_a = tracked_files(a)
    files_in_b = tracked_files(b)
    return filecmp.cmpfiles(a, b, files_in_a | files_in_b, shallow=False)
222
223
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel page and record where the request ended up.

    Stores the response body in `channel.channel_html` and the response's
    final URL in `channel.forwarded_url`, then verifies that the status was
    200 and that the final URL differs from `channel.channel_url`.
    """
    v.status('Fetching channel')
    # The context manager closes the HTTP response even when reading fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
231
232
def parse_channel(v: Verification, channel: TarrableSearchPath) -> GitPin:
    """Parse the downloaded channel page.

    Extracts the release name and git commit from `channel.channel_html`,
    fills `channel.table` with one entry per table row after the first, and
    returns the release name and commit as a GitPin.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # Third whitespace-separated word of the first <tag> element's text.
        first = doc.getElementsByTagName(tag)[0]
        return first.firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    git_revision = commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    rows = doc.getElementsByTagName('tr')
    # The first row is skipped (presumably the table header).
    for row in rows[1:]:
        cells = row.childNodes
        name = cells[0].firstChild.firstChild.nodeValue
        url = cells[0].firstChild.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        entry = ChannelTableEntry(url=url, digest=digest, size=size)
        channel.table[name] = entry
    v.ok()
    return GitPin(release_name=title_name, git_revision=git_revision)
264
265
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of the given bytes."""
    hex_digest = hashlib.sha256(s).hexdigest()
    return Digest16(hex_digest)
268
269
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of a file's contents.

    The file is read in 4096-byte chunks so it never has to fit in memory.
    """
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if chunk == b'':
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
277
278
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 by running `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    conversion = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(conversion.returncode == 0)
    converted = conversion.stdout.decode().strip()
    return Digest16(converted)
285
286
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 by running `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    conversion = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(conversion.returncode == 0)
    converted = conversion.stdout.decode().strip()
    return Digest32(converted)
293
294
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download `url` with nix-prefetch-url and return the resulting path.

    The digest printed by nix-prefetch-url and the digest of the file at the
    returned path are both checked against `digest`.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    prefetch = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)

    # Exactly two newline-terminated lines are expected: digest, then path.
    output_fields = prefetch.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_fields
    assert empty == ''

    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)

    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
311
312
def fetch_resources(
        v: Verification,
        channel: ChannelSearchPath,
        pin: GitPin) -> None:
    """Download the channel's git-revision file and nixexprs tarball.

    For both table entries, sets `absolute_url` (resolved against
    `channel.forwarded_url`) and `file` (the downloaded path). Then verifies
    that the downloaded git-revision file matches `pin.git_revision`.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the file handle is closed promptly.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == pin.git_revision)
327
328
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory used for the bare clone of `git_repo`.

    The directory name is the hex SHA-256 digest of the repository string.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
334
335
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the cache-file path that records the tarball made for `pin`.

    The file name combines the digest of the repository string with the
    pinned revision and release name.
    """
    repo_key = digest_string(channel.git_repo.encode())
    name_parts = (repo_key, pin.git_revision, pin.release_name)
    filename = '%s-%s-%s' % name_parts
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', filename)
344
345
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
    """Verify that `new_revision` sits where it should in the git history.

    Checks that `new_revision` is an ancestor of `channel.git_ref` and, when
    `old_revision` is given, that `old_revision` is an ancestor of
    `new_revision`. Both checks run in the cached clone of the repository.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 when A is an ancestor
        # of B.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(new_revision, channel.git_ref))

    if old_revision is not None:
        # The message names the relation that is actually tested: the old
        # revision must be an ancestor of the new one.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            old_revision)
        v.result(is_ancestor(old_revision, new_revision))
374
375
def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    """Fetch `channel.git_ref` into the cached clone and return its tip.

    Creates the bare cache repository if needed, fetches the ref without
    forcing, optionally verifies that `desired_revision` is now present, and
    runs the ancestry checks on the fetched tip.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    # Read the ref file written by the fetch; the context manager closes the
    # file handle promptly.
    ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
    with open(ref_path) as ref_file:
        new_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision
422
423
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
    """Make sure the pinned revision is present in the cached clone.

    If the cache already holds the revision, only the ancestry checks are
    run; otherwise (or when no cache exists yet) the ref is fetched.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        probe = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        exit_code = probe.returncode
        if exit_code == 0:
            v.status('yes')
        elif exit_code == 1:
            v.status('no')
        # Any exit code other than 0 or 1 counts as a failure.
        v.result(exit_code in (0, 1))
        if exit_code == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)
443
444
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify that the unpacked channel tarball matches the git checkout.

    At least one file must match, none may differ, and the only incomparable
    files allowed are exactly the expected ones listed below.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    # Move the expected entries out of `errors`, leaving the unexpected ones.
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)

    benign_count = len(benign_errors)
    expected_count = len(expected_errors)
    v.check(
        '(%d of %d expected incomparable files)' %
        (benign_count, expected_count),
        benign_count == expected_count)
476
477
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the channel's downloaded nixexprs tarball into `dest`."""
    tarball_path = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball_path)
    shutil.unpack_archive(tarball_path, dest)
    v.ok()
488
489
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Unpack the pinned revision from the cached clone into `dest`.

    Pipes `git archive` into `tar x`; both processes must exit with 0.
    """
    v.status('Checking out corresponding git revision')
    archive_command = [
        'git',
        '-C',
        git_cachedir(channel.git_repo),
        'archive',
        pin.git_revision]
    unpack_command = ['tar', 'x', '-C', dest, '-f', '-']

    archiver = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    unpacker = subprocess.Popen(unpack_command, stdin=archiver.stdout)
    # Drop the parent's copy of the pipe so only tar holds the read end.
    if archiver.stdout:
        archiver.stdout.close()
    unpacker.wait()
    archiver.wait()

    succeeded = archiver.returncode == 0 and unpacker.returncode == 0
    v.result(succeeded)
509
510
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the Nix store path of a tarball of the pinned git revision.

    A cache file remembers the store path of a previously generated tarball;
    it is reused while that path still exists. Otherwise the tarball is
    built with `git archive | xz`, added to the Nix store, and the cache file
    is rewritten.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as cached:
            cached_tarball = cached.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop the parent's copy of the pipe (as git_checkout does) so
            # git is not left blocked on a pipe nobody reads if xz exits
            # early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Close the cache file explicitly so its contents are flushed to disk
        # before returning.
        with open(cache_file, 'w') as cache:
            cache.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
549
550
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Verify the metadata files inside the unpacked channel tarball.

    `.git-revision` must equal `pin.git_revision`, and the content of
    `.version-suffix` must be a suffix of `pin.release_name`.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    v.status('Verifying git commit in channel tarball')
    # Context managers close the metadata files as soon as they are read.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
573
574
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> None:
    """Verify that the channel tarball agrees with the pinned git revision.

    Unpacks the tarball and a git checkout into temporary directories,
    checks the tarball's metadata, and compares the two trees.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, pin, channel_contents)

            git_checkout(v, channel, pin, git_contents)

            compare_tarball_and_git(v, pin, channel_contents, git_contents)

            # The removal itself happens as the `with` blocks exit.
            v.status('Removing temporary directories')
    v.ok()
591
592
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name for `git_revision`.

    The name is the basename of the repository location followed by the
    output of `git log --format=%ct-%h` for that revision.
    """
    v.status('Getting commit date')
    command = [
        'git',
        '-C',
        git_cachedir(channel.git_repo),
        'log',
        '-n1',
        '--format=%ct-%h',
        '--abbrev=11',
        '--no-show-signature',
        git_revision]
    log = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')

    repo_name = os.path.basename(channel.git_repo)
    commit_stamp = log.stdout.decode().strip()
    return '%s-%s' % (repo_name, commit_stamp)
611
612
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the SearchPath object described by a config section.

    The section's `type` key selects the class; every key of the section is
    passed to its constructor as a keyword argument.
    """
    classes_by_type: Mapping[str, Type[SearchPath]] = {
        'alias': AliasSearchPath,
        'channel': ChannelSearchPath,
        'git': GitSearchPath,
    }
    search_path_class = classes_by_type[conf['type']]
    attributes = dict(conf.items())
    return search_path_class(**attributes)
620
621
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI file at `filename` and return the populated parser.

    Raises:
        OSError: if the file cannot be opened.
        configparser.Error: if the file is not valid INI syntax.
    """
    config = configparser.ConfigParser()
    # Open through a context manager so the handle is closed once parsed.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
626
627
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several config files and merge their sections into one dict.

    Raises:
        Exception: if a section name appears more than once across the files.
    """
    sections_by_name: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in sections_by_name:
                raise Exception('Duplicate channel "%s"' % name)
            sections_by_name[name] = parsed[name]
    return sections_by_name
638
639
def pinCommand(args: argparse.Namespace) -> None:
    """Implement the `pin` subcommand.

    Pins every section of `args.channels_file` (or only those named in
    `args.channels`, when given) and writes the updated config back to the
    same file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        # An empty selection means "pin everything".
        selected = not args.channels or section in args.channels
        if selected:
            sp = read_search_path(config[section])
            new_pin = sp.pin(v)
            config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
653
654
def updateCommand(args: argparse.Namespace) -> None:
    """Implement the `update` subcommand.

    Fetches the tarball for every non-alias section of the given config
    files, resolves aliases, and installs all of them with one `nix-env`
    call (or only prints that command when `args.dry_run` is set).
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)

    # First pass: build one expression per non-alias section.
    for section, conf in config.items():
        sp = read_search_path(conf)
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in conf
            continue
        tarball = sp.fetch(v, section, conf)
        # "%%s" leaves a "%s" placeholder that is filled with the channel
        # name when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (conf['release_name'], tarball))

    # Second pass: aliases reuse the expression of the section they name.
    for section, conf in config.items():
        if 'alias_of' in conf:
            target = str(conf['alias_of'])
            exprs[section] = exprs[target]

    profile = '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()
    channel_expressions = [exprs[name] % name for name in sorted(exprs)]
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + channel_expressions

    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
        return
    v.status('Installing channels with nix-env')
    process = subprocess.run(command)
    v.result(process.returncode == 0)
689
690
def main() -> None:
    """Parse the command line and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # pinch pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    # pinch update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    handler = args.func
    handler(args)
704
705
# Run the command-line interface only when executed as a script, not when
# imported as a module.
if __name__ == '__main__':
    main()