from abc import ABC, abstractmethod
import argparse
import configparser
import filecmp
import functools
import getpass
import hashlib
import operator
import os
import os.path
import shlex
import shutil
import subprocess
import sys
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom

from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    NewType,
    Optional,
    Tuple,
    Type,
    Union,
)

# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    XDG_CACHE_HOME: str


xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))


class VerificationError(Exception):
    pass

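# Verification prints progress to stderr as "<description> ... OK/FAIL" lines
# and raises VerificationError as soon as any check fails.  Illustrative use
# (mirroring how the rest of this file calls it):
#
#     v = Verification()
#     v.status('Checking something')
#     v.result(check_passed)         # prints OK or FAIL, may raise
#     v.check('One-shot check', ok)  # status() + result() in one call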
class Verification:

    def __init__(self) -> None:
        self.line_length = 0

    def status(self, s: str) -> None:
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        cols = shutil.get_terminal_size().columns or 80
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        self.result(True)

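# sha256 digests in the two encodings nix uses: Digest16 is the familiar
# 64-character hex form, Digest32 is nix's base-32 form (as printed by
# nix-prefetch-url).  to_Digest16/to_Digest32 below convert between them.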
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


class AliasPin(NamedTuple):
    pass


class GitPin(NamedTuple):
    git_revision: str
    release_name: str


class ChannelPin(NamedTuple):
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str


Pin = Union[AliasPin, GitPin, ChannelPin]

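# Each [section] of the channels file is loaded as one SearchPath.  pin()
# inspects the upstream source and returns the fields to record back into
# the config; the Tarrable subclasses also provide fetch(), which turns an
# already-pinned section into a channel tarball in the Nix store.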
class SearchPath(types.SimpleNamespace, ABC):

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        pass


class AliasSearchPath(SearchPath):
    alias_of: str

    def pin(self, v: Verification) -> AliasPin:
        assert not hasattr(self, 'git_repo')
        return AliasPin()


# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    table: Dict[str, ChannelTableEntry]


class GitSearchPath(TarrableSearchPath):
    def pin(self, v: Verification) -> GitPin:
        old_revision = (
            self.git_revision if hasattr(self, 'git_revision') else None)
        if hasattr(self, 'git_revision'):
            del self.git_revision

        new_revision = git_fetch(v, self, None, old_revision)
        return GitPin(release_name=git_revision_name(v, self, new_revision),
                      git_revision=new_revision)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        if 'git_revision' not in conf or 'release_name' not in conf:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)
        the_pin = GitPin(
            release_name=conf['release_name'],
            git_revision=conf['git_revision'])

        ensure_git_rev_available(v, self, the_pin, None)
        return git_get_tarball(v, self, the_pin)


class ChannelSearchPath(TarrableSearchPath):
    def pin(self, v: Verification) -> ChannelPin:
        old_revision = (
            self.git_revision if hasattr(self, 'git_revision') else None)
        if hasattr(self, 'git_revision'):
            del self.git_revision

        fetch(v, self)
        new_gitpin = parse_channel(v, self)
        fetch_resources(v, self, new_gitpin)
        ensure_git_rev_available(v, self, new_gitpin, old_revision)
        check_channel_contents(v, self)
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=self.table['nixexprs.tar.xz'].absolute_url,
            tarball_sha256=self.table['nixexprs.tar.xz'].digest,
            git_revision=self.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        if 'git_repo' not in conf or 'release_name' not in conf:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(
                conf['tarball_sha256']))

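# compare() walks both trees (ignoring anything under .git/) and returns
# filecmp.cmpfiles' three lists: names that match, names that differ, and
# names that could not be compared (missing from one side or unreadable).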
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:

    def throw(error: OSError) -> None:
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    return filecmp.cmpfiles(a, b, files, shallow=False)


def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    v.status('Fetching channel')
    request = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel.channel_html = request.read()
    channel.forwarded_url = request.geturl()
    v.result(request.status == 200)  # type: ignore # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)

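# parse_channel scrapes the channel's HTML landing page: the release name
# from the <title> (cross-checked against the <h1>), the git commit from the
# <tt> element, and the per-file table of URLs, sizes, and sha256 digests.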
def parse_channel(v: Verification, channel: TarrableSearchPath) -> GitPin:
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
    return GitPin(release_name=title_name, git_revision=channel.git_revision)


def digest_string(s: bytes) -> Digest16:
    return Digest16(hashlib.sha256(s).hexdigest())


def digest_file(filename: str) -> Digest16:
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # pylint: disable=cell-var-from-loop
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return Digest16(hasher.hexdigest())


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    v.status('Converting digest to base16')
    process = subprocess.run(
        ['nix', 'to-base16', '--type', 'sha256', digest32], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    v.status('Converting digest to base32')
    process = subprocess.run(
        ['nix', 'to-base32', '--type', 'sha256', digest16], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())

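# nix-prefetch-url --print-path writes two lines to stdout: the (base-32)
# hash and then the store path, which is what the split('\n') below relies on.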
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore # (for old mypy)


def fetch_resources(
        v: Verification,
        channel: ChannelSearchPath,
        pin: GitPin) -> None:
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    v.result(
        open(
            channel.table['git-revision'].file).read(999) == pin.git_revision)

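# Local caches live under $XDG_CACHE_HOME/pinch: bare git clones in
# pinch/git/<sha256 of repo URL>, and generated tarballs' store paths in
# pinch/git-tarball/<repo digest>-<revision>-<release name>.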
def git_cachedir(git_repo: str) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git',
        digest_string(git_repo.encode()))


def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git-tarball',
        '%s-%s-%s' %
        (digest_string(channel.git_repo.encode()),
         pin.git_revision,
         pin.release_name))


def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              new_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if old_revision is not None:
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            old_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  old_revision,
                                  new_revision])
        v.result(process.returncode == 0)

def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    new_revision = open(
        os.path.join(
            cachedir,
            'refs',
            'heads',
            channel.git_ref)).read(999).strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision


def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        v.result(process.returncode == 0 or process.returncode == 1)
        if process.returncode == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)

def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))


def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    v.status('Extracting tarball %s' %
             channel.table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(
        channel.table['nixexprs.tar.xz'].file,
        dest)
    v.ok()


def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)

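# git_get_tarball reproduces the channel tarball from the bare git cache:
# "git archive --prefix=<release_name>/ <rev>" piped through xz, added to the
# Nix store with "nix-store --add", with the resulting store path memoized
# in the pinch/git-tarball cache file.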
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        cached_tarball = open(cache_file).read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        open(cache_file, 'w').write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)

def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    v.status('Verifying git commit in channel tarball')
    v.result(
        open(
            os.path.join(
                channel_contents,
                channel.release_name,
                '.git-revision')).read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = open(
        os.path.join(
            channel_contents,
            channel.release_name,
            '.version-suffix')).read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))


def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()

def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              git_revision],
                             stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())


def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    mapping: Mapping[str, Type[SearchPath]] = {
        'alias': AliasSearchPath,
        'channel': ChannelSearchPath,
        'git': GitSearchPath,
    }
    return mapping[conf['type']](**dict(conf.items()))

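# The channels file read here is ordinary configparser INI.  A hypothetical
# example (section names and values are illustrative, not canonical):
#
#     [nixpkgs]
#     type = channel
#     channel_url = https://channels.nixos.org/nixos-23.11
#     git_repo = https://github.com/NixOS/nixpkgs.git
#     git_ref = nixos-23.11
#
#     [pkgs]
#     type = alias
#     alias_of = nixpkgs
#
# "pin" then writes fields like release_name, git_revision, tarball_url, and
# tarball_sha256 back into the git/channel sections for "update" to consume.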
def read_config(filename: str) -> configparser.ConfigParser:
    config = configparser.ConfigParser()
    config.read_file(open(filename), filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception('Duplicate channel "%s"' % section)
            merged_config[section] = config[section]
    return merged_config

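# "pinch pin": re-pin every section (or only the sections named on the
# command line) and write the updated fields back into the channels file.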
def pinCommand(args: argparse.Namespace) -> None:
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        sp = read_search_path(config[section])

        config[section].update(sp.pin(v)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)

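# "pinch update": turn each pinned section into a tarball in the Nix store,
# then install all of them into the per-user channels profile via nix-env
# (aliases reuse their target's expression).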
def updateCommand(args: argparse.Namespace) -> None:
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp = read_search_path(config[section])
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in config[section]
            continue
        tarball = sp.fetch(v, section, config[section])
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)

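# Command-line entry point.  Illustrative invocations (file names here are
# just examples):
#
#     pinch pin channels.ini            # pin every channel in the file
#     pinch pin channels.ini nixpkgs    # pin only the named channel(s)
#     pinch update --dry-run channels.ini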
def main() -> None:
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pinCommand)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=updateCommand)
    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()