]> git.scottworley.com Git - pinch/blob - pinch.py
Require type to be specified in config
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 Mapping,
26 NamedTuple,
27 NewType,
28 Tuple,
29 Type,
30 Union,
31 )
32
33 # Use xdg module when it's less painful to have as a dependency
34
35
class XDG(types.SimpleNamespace):
    """Typed namespace for the XDG base-directory paths this tool uses."""
    # Base directory for cached data (git clones, tarball records).
    XDG_CACHE_HOME: str
39
# Resolve XDG_CACHE_HOME per the XDG Base Directory convention: use the
# environment variable when set, falling back to ~/.cache.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))
44
45
class VerificationError(Exception):
    """Raised by Verification.result when a check fails."""
    pass
48
49
class Verification:
    """Prints progress messages to stderr with right-aligned OK/FAIL verdicts.

    Any FAIL verdict raises VerificationError, aborting the run.
    """

    def __init__(self) -> None:
        # Width of the text already printed on the current line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print an in-progress message (no newline) and track its width."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        # Wrap s in the ANSI escape sequence for color code c.
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print a right-aligned colored verdict; raise on failure."""
        if r:
            verdict, code = 'OK ', 92
        else:
            verdict, code = 'FAIL', 91
        columns = shutil.get_terminal_size().columns or 80
        # Pad so the verdict lands at the right edge of the terminal.
        padding = (columns - (self.line_length + len(verdict))) % columns
        print(' ' * padding + self._color(verdict, code), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print a status message immediately followed by its verdict."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the preceding status message."""
        self.result(True)
79
80
# Distinct types for the two hash encodings in play so they cannot be
# mixed up: Digest16 is hex (base-16), Digest32 is Nix's base-32.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
83
84
class ChannelTableEntry(types.SimpleNamespace):
    """Metadata for one downloadable file listed on a channel page."""
    # URL resolved against the channel's post-redirect URL.
    absolute_url: str
    # Expected sha256 (hex) of the file, as scraped from the page.
    digest: Digest16
    # Local Nix-store path once fetched.
    file: str
    # Size in bytes, as listed on the page.
    size: int
    # URL exactly as it appears on the page (possibly relative).
    url: str
91
92
class AliasPin(NamedTuple):
    """An alias needs no pinned state of its own."""
    pass
95
96
class GitPin(NamedTuple):
    """Pinned state for a git-type search path."""
    git_revision: str
    release_name: str
100
101
class ChannelPin(NamedTuple):
    """Pinned state for a channel-type search path."""
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str
107
108
# Any of the pinned-state records written back into the config file.
Pin = Union[AliasPin, GitPin, ChannelPin]
110
111
class SearchPath(types.SimpleNamespace, ABC):
    """One NIX_PATH entry, constructed from a config-file section."""

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        """Resolve this search path to reproducible pinned state."""
        pass
117
118
class AliasSearchPath(SearchPath):
    """A channel that is just another name for a different channel."""
    # Name of the channel section this one aliases.
    alias_of: str

    def pin(self, v: Verification) -> AliasPin:
        """Nothing to pin; the target channel carries the real state."""
        # An alias section must not also be configured as a git channel.
        assert not hasattr(self, 'git_repo')
        return AliasPin()
125
126
# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """Base for search paths that can produce a nixexprs tarball.

    Attributes are populated incrementally by the module-level
    fetch/parse helpers rather than by a constructor.
    """
    # Raw bytes of the channel's HTML description page.
    channel_html: bytes
    # The configured channel URL (pre-redirect).
    channel_url: str
    # Where channel_url redirected to (the release-specific page).
    forwarded_url: str
    # Git ref to fetch.  NOTE(review): git_fetch reads refs/heads/<git_ref>,
    # so this looks branch-only -- confirm tags are unsupported.
    git_ref: str
    # URL of the git repository.
    git_repo: str
    # The previously pinned revision, present only when re-pinning.
    old_git_revision: str
    # Filename -> metadata for files listed on the channel page.
    table: Dict[str, ChannelTableEntry]
137
138
class GitSearchPath(TarrableSearchPath):
    """A search path pinned directly to a revision of a git repository."""

    def pin(self, v: Verification) -> GitPin:
        """Fetch the ref and record the revision it currently points at."""
        if hasattr(self, 'git_revision'):
            # Remember the previous pin so ancestry can be verified.
            self.old_git_revision = self.git_revision
            del self.git_revision

        git_fetch(v, self)
        name = git_revision_name(v, self)
        return GitPin(release_name=name, git_revision=self.git_revision)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Produce a Nix-store tarball for the pinned revision."""
        if not ('git_repo' in conf and 'release_name' in conf):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        ensure_git_rev_available(v, self)
        return git_get_tarball(v, self)
158
159
class ChannelSearchPath(TarrableSearchPath):
    """A search path following a Nix channel's published release pages."""

    def pin(self, v: Verification) -> ChannelPin:
        """Snapshot the channel's current release as pinned state."""
        if hasattr(self, 'git_revision'):
            # Remember the previous pin so ancestry can be verified.
            self.old_git_revision = self.git_revision
            del self.git_revision

        pin_channel(v, self)
        nixexprs = self.table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=self.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest,
            git_revision=self.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Fetch the pinned tarball into the Nix store and return its path."""
        if not ('git_repo' in conf and 'release_name' in conf):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        expected = Digest16(conf['tarball_sha256'])
        return fetch_with_nix_prefetch_url(v, conf['tarball_url'], expected)
185
186
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Deep-compare the directory trees rooted at a and b.

    Returns (matching, mismatched, incomparable) relative filenames, as
    produced by filecmp.cmpfiles over the union of both trees' contents.
    Symlinks to directories are compared as entries themselves; anything
    under '.git/' is ignored.
    """

    def raise_error(error: OSError) -> None:
        # os.walk swallows errors by default; make them fatal instead.
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def walk_files(d: str) -> Iterable[str]:
        found: List[str] = []
        for path, dirs, files in os.walk(d, onerror=raise_error):
            rel = os.path.relpath(path, start=d)
            for f in files:
                found.append(join(rel, f))
            # os.walk does not descend into symlinked directories; record
            # the link itself so it participates in the comparison.
            for entry in dirs:
                if os.path.islink(join(path, entry)):
                    found.append(join(rel, entry))
        return found

    def without_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    names = set(without_dot_git(walk_files(a))) | \
        set(without_dot_git(walk_files(b)))
    return filecmp.cmpfiles(a, b, names, shallow=False)
213
214
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel's HTML description page.

    Stores the page body in channel.channel_html and the post-redirect
    URL in channel.forwarded_url.  Fails verification unless the server
    returned 200 and the request was redirected (channel URLs are
    expected to forward to a release-specific page).
    """
    v.status('Fetching channel')
    request = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel.channel_html = request.read()
    channel.forwarded_url = request.geturl()
    v.result(request.status == 200)  # type: ignore  # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
222
223
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Scrape release metadata out of the channel's HTML page.

    The page must be well-formed XML.  Populates channel.release_name,
    channel.git_revision, and channel.table (filename ->
    ChannelTableEntry for each row of the page's file table).
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    # The release name is taken as word index 2 of both the <title> and
    # the <h1>; the two must agree.  (Presumably the text is of the form
    # "<something> <something> <release-name> ..." -- TODO confirm.)
    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    # The git revision is the first <tt> element; check the text just
    # before it to make sure we grabbed the right one.
    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    # Each table row after the header is (link, size, digest).
    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
255
256
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 hex digest of a byte string."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
259
260
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 hex digest of the named file's contents."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # Hash in fixed-size chunks so large files stay memory-cheap.
        block = f.read(4096)
        while block:
            hasher.update(block)
            block = f.read(4096)
    return Digest16(hasher.hexdigest())
268
269
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base-32 hash to base-16 (hex) via `nix to-base16`."""
    v.status('Converting digest to base16')
    process = subprocess.run(
        ['nix', 'to-base16', '--type', 'sha256', digest32], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())
276
277
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base-16 (hex) hash to base-32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    process = subprocess.run(
        ['nix', 'to-base32', '--type', 'sha256', digest16], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())
284
285
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url into the Nix store and return the resulting store path.

    The expected digest is passed to nix-prefetch-url, and additionally
    re-checked two ways: against the digest nix-prefetch-url reports
    (converted from base-32) and by hashing the stored file ourselves.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Expected output is exactly two lines: the base-32 digest, then the
    # store path, then a trailing newline.
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore # (for old mypy)
302
303
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel's git-revision file and nixexprs tarball.

    Resolves each table entry's URL against the channel's post-redirect
    URL, fetches it into the Nix store via nix-prefetch-url, and then
    verifies that the downloaded git-revision file matches the revision
    scraped from the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Fix: close the file handle instead of leaking it from a bare open().
    with open(channel.table['git-revision'].file) as rev_file:
        v.result(rev_file.read(999) == channel.git_revision)
315
316
def git_cachedir(git_repo: str) -> str:
    """Path under XDG_CACHE_HOME where this repo's bare clone lives."""
    repo_digest = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_digest)
322
323
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Path of the cache entry recording this pin's Nix-store tarball."""
    entry = '%s-%s-%s' % (digest_string(channel.git_repo.encode()),
                          channel.git_revision,
                          channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry)
332
333
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Check that the pinned revision is an ancestor of the fetched ref.

    When re-pinning (old_git_revision is set), additionally require the
    new revision to be a descendant of the old one, so a pin never moves
    backwards or onto a rewritten branch.
    """
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              channel.git_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if hasattr(channel, 'old_git_revision'):
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            channel.old_git_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  channel.old_git_revision,
                                  channel.git_revision])
        v.result(process.returncode == 0)
358
359
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local cache.

    Ensures channel.git_revision is set: if it was already set, verify
    the fetch retrieved that object; otherwise read the revision the
    freshly fetched ref points at.  Finishes with ancestry checks.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Fix: close the ref file instead of leaking the handle.
        # NOTE(review): this assumes git_ref is a branch (refs/heads/...)
        # and that the ref is loose, not packed -- confirm.
        with open(os.path.join(
                cachedir, 'refs', 'heads', channel.git_ref)) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
400
401
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure channel.git_revision is present in the local git cache,
    fetching from the remote only when it is not already there."""
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        # Exit status 1 just means "object not found"; anything else is a
        # genuine git failure.
        v.result(process.returncode == 0 or process.returncode == 1)
        if process.returncode == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
419
420
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify the extracted channel tarball against the git checkout.

    Requires at least one matching file, no differing files, and no
    incomparable files beyond the known channel-build artifacts.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Files the channel build generates that have no git counterpart.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))
452
453
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the fetched nixexprs tarball into dest."""
    v.status('Extracting tarball %s' %
             channel.table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(
        channel.table['nixexprs.tar.xz'].file,
        dest)
    v.ok()
464
465
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Materialize the pinned git revision into dest via `git archive | tar`."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Close our copy of the pipe so tar is the sole reader and git gets
    # SIGPIPE if tar exits early (standard subprocess pipeline idiom).
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
484
485
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Build (or reuse) a release tarball for the pinned git revision.

    Returns the Nix-store path of the tarball.  A small cache file maps
    (repo, revision, release name) to a previously added store path so
    the archive is not regenerated unnecessarily.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        # Fix: close the cache file instead of leaking the handle.
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        # The store path may have been garbage-collected; only trust the
        # cache entry if it still exists.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Fix: write the cache entry via a context manager so the handle
        # is closed and the data flushed promptly.
        with open(cache_file, 'w') as f:
            f.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
521
522
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Verify the extracted tarball's embedded metadata.

    Checks that its .git-revision matches channel.git_revision and that
    its .version-suffix is a suffix of the release name.
    """
    v.status('Verifying git commit in channel tarball')
    # Fix: read these metadata files via context managers instead of
    # leaking file handles from bare open() calls.
    with open(os.path.join(
            channel_contents,
            channel.release_name,
            '.git-revision')) as rev_file:
        v.result(rev_file.read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(
            channel_contents,
            channel.release_name,
            '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
545
546
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Extract the channel tarball and verify it against git.

    Cross-checks the tarball's metadata files, then compares its file
    tree with a fresh checkout of the pinned revision.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
        v.ok()
562
563
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Run the full pin pipeline for a channel-type search path: download
    and parse the channel page, fetch its files, make the git revision
    available locally, and cross-check the tarball against git."""
    fetch(v, channel)
    parse_channel(v, channel)
    fetch_resources(v, channel)
    ensure_git_rev_available(v, channel)
    check_channel_contents(v, channel)
570
571
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name of the form
    '<repo-basename>-<commit-timestamp>-<11-char-abbrev-hash>' from the
    pinned revision's commit metadata."""
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              channel.git_revision],
                             stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
587
588
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Instantiate the SearchPath subclass named by the section's 'type'
    key, passing every key/value of the section as an attribute."""
    constructors: Mapping[str, Type[SearchPath]] = {
        'alias': AliasSearchPath,
        'channel': ChannelSearchPath,
        'git': GitSearchPath,
    }
    constructor = constructors[conf['type']]
    return constructor(**dict(conf.items()))
596
597
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse filename as an INI-style pinch channel configuration."""
    config = configparser.ConfigParser()
    # Fix: close the config file instead of leaking the handle.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
602
603
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several config files into one mapping,
    rejecting any channel name that appears in more than one file."""
    merged: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged:
                raise Exception('Duplicate channel "%s"' % name)
            merged[name] = parsed[name]
    return merged
614
615
def pin(args: argparse.Namespace) -> None:
    """`pinch pin`: resolve channels to pinned state and write it back.

    Rewrites args.channels_file in place, adding each search path's Pin
    fields to its section.  If specific channel names were given on the
    command line, only those sections are (re)pinned.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        sp = read_search_path(config[section])

        config[section].update(sp.pin(v)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
629
630
def update(args: argparse.Namespace) -> None:
    """`pinch update`: install the pinned channels into the user's
    nix-env channels profile.

    Builds one nix expression per non-alias channel from its pinned
    tarball, points aliases at their targets' expressions, then runs
    nix-env --install (or just prints the command under --dry-run).
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp = read_search_path(config[section])
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in config[section]
            continue
        tarball = sp.fetch(v, section, config[section])
        # channelName is left as '%s' and filled in with the section name
        # (or the alias name) just before invoking nix-env below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    # Aliases reuse their target channel's expression.
    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
665
666
def main() -> None:
    """Parse the command line and dispatch to the pin or update command."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pin)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=update)
    args = parser.parse_args()
    args.func(args)
680
681
# Script entry point; keeps the module importable without side effects.
if __name__ == '__main__':
    main()