]> git.scottworley.com Git - pinch/blob - pinch.py
5e6aab9f6dcaab63bcbd936069b51654a7cd6182
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 Mapping,
26 NamedTuple,
27 NewType,
28 Tuple,
29 Type,
30 Union,
31 )
32
33 # Use xdg module when it's less painful to have as a dependency
34
35
class XDG(types.SimpleNamespace):
    """Minimal stand-in for the ``xdg`` module: the XDG base directories we use."""
    # Base directory for user-specific cached data.
    XDG_CACHE_HOME: str


# The XDG Base Directory Specification treats an unset *or empty*
# $XDG_CACHE_HOME as "use the default", so an empty value must not be taken
# literally (os.getenv would hand back '' and the default would be skipped).
xdg = XDG(
    XDG_CACHE_HOME=(
        os.environ.get('XDG_CACHE_HOME') or os.path.expanduser('~/.cache')))
44
45
class VerificationError(Exception):
    """Signals that a verification step did not succeed."""
48
49
class Verification:
    """Reports verification steps on stderr as 'description ... OK/FAIL' lines.

    Descriptions are printed with status(); result() then prints the verdict,
    padded so that it ends at the right-hand edge of the terminal, and raises
    VerificationError when the verdict is a failure.
    """

    def __init__(self) -> None:
        # Characters written so far on the current (unfinished) line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* plus a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # +1 accounts for the trailing space.
        self.line_length = self.line_length + len(s) + 1  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI escape sequence for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with a verdict; raise VerificationError if *r* is false."""
        if r:
            verdict, color_code = 'OK ', 92
        else:
            verdict, color_code = 'FAIL', 91
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(verdict)
        # The modulo keeps the padding sane when the text already wrapped.
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(verdict, color_code), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print description *s* and immediately report verdict *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the step described so far."""
        self.result(True)
79
80
# Distinct string types for the two textual digest encodings handled in this
# file, so the type checker can tell them apart.
Digest16 = NewType('Digest16', str)  # base16 (hexadecimal) digest
Digest32 = NewType('Digest32', str)  # base32 digest
83
84
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table parsed from a channel page.

    parse_channel() creates entries with url, digest and size;
    fetch_resources() later adds absolute_url and file.
    """
    # Resolved against the channel's forwarded URL.
    absolute_url: str
    # Digest listed for the file in the channel table.
    digest: Digest16
    # Local path returned by fetch_with_nix_prefetch_url().
    file: str
    # Size listed for the file in the channel table.
    size: int
    # href exactly as it appears in the channel table.
    url: str
91
92
class AliasPin(NamedTuple):
    """Pin for an alias search path; it carries no fields."""
95
96
class GitPin(NamedTuple):
    """Pin identifying a git revision together with its release name."""
    git_revision: str
    release_name: str
100
101
class ChannelPin(NamedTuple):
    """Pin for a channel: git revision, release name, and the channel tarball."""
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str
107
108
# Any of the pin record types a search path's pin() method may return.
Pin = Union[AliasPin, GitPin, ChannelPin]
110
111
class SearchPath(types.SimpleNamespace, ABC):
    """Abstract base for one configured search-path entry.

    Instances are namespaces; read_search_path() fills them with the keys of
    a config section.
    """

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        """Produce the pin for this entry, reporting progress through *v*."""
117
118
class AliasSearchPath(SearchPath):
    """Search path that merely refers to another one by name."""
    # Name of the search path this one is an alias of.
    alias_of: str

    def pin(self, v: Verification) -> AliasPin:
        """Return the (empty) alias pin; an alias must not carry a git_repo."""
        assert not hasattr(self, 'git_repo')
        return AliasPin()
125
126
127 # (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
128 # which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """Common attribute declarations for the git and channel search paths."""
    # Raw channel page, stored by fetch().
    channel_html: bytes
    # Address the channel page is requested from.
    channel_url: str
    # Final URL after redirects, stored by fetch().
    forwarded_url: str
    # Ref fetched from git_repo.
    git_ref: str
    # Repository that git operations run against.
    git_repo: str
    # Previously pinned revision, saved by pin() before re-pinning.
    old_git_revision: str
    # File table of the channel page, built by parse_channel().
    table: Dict[str, ChannelTableEntry]
137
138
class GitSearchPath(TarrableSearchPath):
    """Search path pinned directly to a revision of a git repository."""

    def pin(self, v: Verification) -> GitPin:
        """Fetch the configured ref and return a pin for the fetched revision."""
        # Keep the previous pin around so the ancestry check can use it, and
        # drop the current one so git_fetch() reads the revision afresh.
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        git_fetch(v, self)
        name = git_revision_name(v, self)
        return GitPin(git_revision=self.git_revision, release_name=name)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the tarball path for the pinned revision recorded in *conf*.

        Raises Exception when the section has not been pinned yet.
        """
        is_pinned = 'git_revision' in conf and 'release_name' in conf
        if not is_pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        the_pin = GitPin(
            git_revision=conf['git_revision'],
            release_name=conf['release_name'])
        ensure_git_rev_available(v, self, the_pin)
        return git_get_tarball(v, self, the_pin)
161
162
class ChannelSearchPath(TarrableSearchPath):
    """Search path pinned to a release published on a channel page."""

    def pin(self, v: Verification) -> ChannelPin:
        """Fetch and verify the channel, then return a pin for its release."""
        # Keep the previous pin around for the ancestry check and drop the
        # current one; parse_channel() stores the new revision.
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        fetch(v, self)
        new_gitpin = parse_channel(v, self)
        fetch_resources(v, self, new_gitpin)
        ensure_git_rev_available(v, self, new_gitpin)
        check_channel_contents(v, self)
        tarball = self.table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=self.release_name,
            tarball_url=tarball.absolute_url,
            tarball_sha256=tarball.digest,
            git_revision=self.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Download the pinned channel tarball recorded in *conf*; return its path.

        Raises Exception when the section is missing any of the pin fields.
        """
        # Check every key that is needed below, so a partly pinned section
        # gets the helpful message instead of a bare KeyError.
        required = (
            'git_repo',
            'release_name',
            'tarball_url',
            'tarball_sha256')
        if any(key not in conf for key in required):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(
                conf['tarball_sha256']))
192
193
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the contents of directory trees *a* and *b* file by file.

    Every regular file and every symlinked directory found in either tree
    (paths relative to the tree root, anything under '.git/' ignored) is
    compared by content.  Returns the (match, mismatch, errors) lists of
    filecmp.cmpfiles.
    """

    def reraise(error: OSError) -> None:
        # os.walk swallows errors unless onerror re-raises them.
        raise error

    def tracked_paths(root: str) -> Iterable[str]:
        found: List[str] = []
        for current, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(current, start=root)
            # Joining with '' leaves top-level names without a './' prefix.
            prefix = '' if relative == '.' else relative
            found += [os.path.join(prefix, name) for name in filenames]
            # os.walk lists symlinks to directories among the directories;
            # compare them as entries instead of descending into them.
            found += [os.path.join(prefix, name) for name in subdirs
                      if os.path.islink(os.path.join(current, name))]
        return [path for path in found if not path.startswith('.git/')]

    candidates = set(tracked_paths(a)) | set(tracked_paths(b))
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
220
221
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel page and record it on *channel*.

    Stores the page body in channel.channel_html and the final URL in
    channel.forwarded_url, then verifies that the response status was 200
    and that the request was forwarded to a different URL.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed even
    # if reading fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
229
230
def parse_channel(v: Verification, channel: TarrableSearchPath) -> GitPin:
    """Parse channel.channel_html and record what it describes on *channel*.

    Sets channel.release_name, channel.git_revision and channel.table, and
    returns a GitPin built from the release name and git revision.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # The release name is taken as the third whitespace-separated word
        # of the element's text.
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_element = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_element.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_element.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first row is skipped.
    for row in doc.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
    return GitPin(release_name=title_name, git_revision=channel.git_revision)
263
264
def digest_string(s: bytes) -> Digest16:
    """Return the hexadecimal SHA-256 digest of the bytes *s*."""
    hasher = hashlib.sha256(s)
    return Digest16(hasher.hexdigest())
267
268
def digest_file(filename: str) -> Digest16:
    """Return the hexadecimal SHA-256 digest of the file at *filename*."""
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        # Hash in 4 KiB chunks so the whole file is never held in memory.
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
276
277
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 by running ``nix to-base16``."""
    v.status('Converting digest to base16')
    argv = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
284
285
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 by running ``nix to-base32``."""
    v.status('Converting digest to base32')
    argv = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
292
293
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch *url* with nix-prefetch-url and return the path it printed.

    Verifies that the command succeeded, that the digest it printed matches
    *digest*, and that the file at the printed path hashes to *digest*.
    """
    v.status('Fetching %s' % url)
    argv = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)

    # Expected output: the digest line, the path line, and nothing after the
    # final newline.
    output_lines = completed.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''

    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)

    v.status("Verifying file digest")
    actual = digest_file(path)
    v.result(actual == digest)
    return path  # type: ignore # (for old mypy)
310
311
def fetch_resources(
        v: Verification,
        channel: ChannelSearchPath,
        pin: GitPin) -> None:
    """Download the channel's 'git-revision' and 'nixexprs.tar.xz' files.

    Records the absolute URL and the downloaded path on each table entry,
    then verifies that the downloaded git-revision file matches
    pin.git_revision.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        # Table URLs are resolved against the URL the channel page was
        # actually served from.
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)

    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the file handle is not leaked.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == pin.git_revision)
326
327
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory for *git_repo*, keyed by a digest of its name."""
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
333
334
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the path of the cache file for *channel*'s generated tarball.

    The file name combines a digest of the repository name, the git revision
    and the release name.
    """
    repo_key = digest_string(channel.git_repo.encode())
    filename = '%s-%s-%s' % (
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', filename)
343
344
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Verify that channel.git_revision sits where it should in the history.

    Checks that the revision is an ancestor of channel.git_ref and, when a
    previous revision is recorded, that the previous revision is an ancestor
    of the new one.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # Succeeds when the first commit is an ancestor of the second.
        completed = subprocess.run(['git',
                                    '-C',
                                    cachedir,
                                    'merge-base',
                                    '--is-ancestor',
                                    ancestor,
                                    descendant])
        return completed.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The message used to claim the opposite direction of what the
        # command checks.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
369
370
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch channel.git_ref from channel.git_repo into the cache repository.

    Creates the bare cache repository if needed.  If channel.git_revision is
    already set, verifies that the revision is now present; otherwise sets it
    from the fetched ref.  Finishes with the ancestry checks.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Read the ref file through a context manager so the handle is
        # closed promptly.
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
411
412
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> None:
    """Make sure pin.git_revision exists in the cache repository.

    If the cache already contains the revision, only the ancestry checks are
    run; otherwise (or when there is no cache yet) the ref is fetched.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        probe = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        code = probe.returncode
        if code == 0:
            v.status('yes')
        elif code == 1:
            v.status('no')
        # Any exit status other than 0 or 1 counts as a failure.
        v.result(code in (0, 1))
        if code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
431
432
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify that the unpacked channel tarball matches the git checkout.

    At least one file must match, none may differ, and the only files that
    cannot be compared must be exactly the expected ones.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Files that are expected not to be comparable between the two trees.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
464
465
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the downloaded 'nixexprs.tar.xz' of *channel* into *dest*."""
    tarball_path = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball_path)
    shutil.unpack_archive(tarball_path, dest)
    v.ok()
476
477
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack channel.git_revision from the cache repository into *dest*.

    Pipes ``git archive`` into ``tar x`` and verifies that both succeeded.
    """
    v.status('Checking out corresponding git revision')
    archive_argv = [
        'git',
        '-C',
        git_cachedir(channel.git_repo),
        'archive',
        channel.git_revision]
    git = subprocess.Popen(archive_argv, stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Drop our copy of the pipe so tar holds the only read end.
    if git.stdout is not None:
        git.stdout.close()
    tar.wait()
    git.wait()
    both_succeeded = git.returncode == 0 and tar.returncode == 0
    v.result(both_succeeded)
496
497
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the Nix store path of a tarball of pin.git_revision.

    A cache file remembers the store path of a previously generated tarball;
    it is reused if that path still exists.  Otherwise the tarball is built
    with ``git archive | xz``, added to the Nix store, and the cache file is
    (re)written.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        # Context managers are used for all file access here so handles are
        # not leaked.
        with open(cache_file) as cached:
            cached_tarball = cached.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop our copy of the pipe, as git_checkout does, so xz holds
            # the only read end.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'w') as cache:
            cache.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
536
537
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Verify the metadata files inside the unpacked channel tarball.

    '.git-revision' must equal channel.git_revision, and the contents of
    '.version-suffix' must be a suffix of channel.release_name.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Files are read through context managers so handles are not leaked.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
560
561
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Unpack the channel tarball and the git revision and verify they agree.

    Both are unpacked into temporary directories that are removed afterwards.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            # The verdict is reported only after both directories are gone.
            v.status('Removing temporary directories')
    v.ok()
577
578
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name for channel.git_revision.

    The name is '<repo basename>-<commit timestamp>-<abbreviated hash>', with
    the last two parts produced by ``git log``.
    """
    v.status('Getting commit date')
    log_argv = [
        'git',
        '-C',
        git_cachedir(channel.git_repo),
        'log',
        '-n1',
        '--format=%ct-%h',
        '--abbrev=11',
        '--no-show-signature',
        channel.git_revision]
    completed = subprocess.run(log_argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0 and completed.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    stamp = completed.stdout.decode().strip()
    return '%s-%s' % (repo_name, stamp)
594
595
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Instantiate the SearchPath subclass selected by the section's 'type' key.

    Every key of the section is passed to the constructor as an attribute.
    """
    constructors: Mapping[str, Type[SearchPath]] = {
        'alias': AliasSearchPath,
        'channel': ChannelSearchPath,
        'git': GitSearchPath,
    }
    search_path_class = constructors[conf['type']]
    attributes = {key: value for key, value in conf.items()}
    return search_path_class(**attributes)
603
604
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style file *filename* and return the resulting parser."""
    config = configparser.ConfigParser()
    # Open through a context manager so the file is closed after parsing.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
609
610
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several config files and merge their sections into one mapping.

    Raises Exception if the same section name occurs more than once.
    """
    sections_by_name: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in sections_by_name:
                raise Exception('Duplicate channel "%s"' % name)
            sections_by_name[name] = parsed[name]
    return sections_by_name
621
622
def pinCommand(args: argparse.Namespace) -> None:
    """Implement the 'pin' subcommand.

    Pins every section of args.channels_file (or only those named in
    args.channels, when given) and writes the updated config back to the
    same file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = [name for name in config.sections()
                if not args.channels or name in args.channels]
    for name in selected:
        section_conf = config[name]
        search_path = read_search_path(section_conf)
        new_pin = search_path.pin(v)
        section_conf.update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
636
637
def updateCommand(args: argparse.Namespace) -> None:
    """Implement the 'update' subcommand.

    Fetches the tarball of every pinned, non-alias section, builds one Nix
    expression per section (aliases reuse their target's expression), and
    installs them with nix-env -- or just prints the command when
    args.dry_run is set.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)

    for section in config:
        section_conf = config[section]
        sp = read_search_path(section_conf)
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in section_conf
            continue
        tarball = sp.fetch(v, section, section_conf)
        # '%%s' leaves a '%s' placeholder for the channel name, which is
        # substituted when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (section_conf['release_name'], tarball))

    # Aliases are resolved only after all real channels have expressions.
    for section in config:
        if 'alias_of' in config[section]:
            target = str(config[section]['alias_of'])
            exprs[section] = exprs[target]

    profile = '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()
    expressions = [exprs[name] % name for name in sorted(exprs)]
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + expressions

    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
        return
    v.status('Installing channels with nix-env')
    process = subprocess.run(command)
    v.result(process.returncode == 0)
672
673
def main() -> None:
    """Parse the command line and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    commands = parser.add_subparsers(dest='mode', required=True)

    # pinch pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = commands.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    # pinch update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = commands.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    parsed = parser.parse_args()
    parsed.func(parsed)
687
688
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()