]> git.scottworley.com Git - pinch/blob - pinch.py
Keep config mutability at top level
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 NamedTuple,
26 NewType,
27 Tuple,
28 Union,
29 )
30
31 # Use xdg module when it's less painful to have as a dependency
32
33
class XDG(types.SimpleNamespace):
    """The subset of the XDG base directories that pinch uses."""
    XDG_CACHE_HOME: str


def _xdg_cache_home() -> str:
    """Return the XDG base directory for cache files.

    Follows the XDG Base Directory spec: $XDG_CACHE_HOME is honoured only
    when it holds an absolute path.  If it is unset, empty or relative, the
    default ~/.cache is used instead.  Using an empty value as-is would make
    every cache path relative to the current working directory.
    """
    configured = os.getenv('XDG_CACHE_HOME', '')
    if os.path.isabs(configured):
        return configured
    return os.path.expanduser('~/.cache')


# Resolved once at import time.
xdg = XDG(XDG_CACHE_HOME=_xdg_cache_home())
42
43
class VerificationError(Exception):
    """Raised by Verification.result when a verification step fails."""
46
47
class Verification:
    """Console reporter for a sequence of verification steps.

    A step writes its description to stderr with status() and is finished
    by result(), which pads the line and prints an OK or FAIL marker
    (ANSI colour 92 or 91).  A failing result raises VerificationError, so
    a run stops at the first failed check.
    """

    def __init__(self) -> None:
        # Characters written to the current stderr line so far.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Append s, followed by a space, to the current status line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # Counts code points, not display columns (wide characters will
        # misalign the marker).
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in the ANSI SGR escape for code c, then reset."""
        return '\033[{:2d}m{}\033[00m'.format(c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK (r true) or FAIL (r false).

        Raises VerificationError after printing FAIL.
        """
        markers = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = markers[r]
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Report description s and its outcome r in a single call."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Mark the current step as successful."""
        self.result(True)
77
78
# sha256 digests appear in two text encodings.  Distinct NewTypes keep the
# type checker from mixing them up; at runtime both are plain str.
# Digest16: hex form (hashlib's hexdigest(), the channel page's table).
Digest16 = NewType('Digest16', str)
# Digest32: Nix base-32 form, as printed by nix-prefetch-url.
Digest32 = NewType('Digest32', str)
81
82
class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table, plus later fetch results.

    parse_channel sets url, digest and size from the HTML table;
    fetch_resources later fills in absolute_url and file.
    """
    absolute_url: str  # url resolved against the channel's forwarded URL
    digest: Digest16  # sha256 (hex) listed in the table
    file: str  # local path returned by fetch_with_nix_prefetch_url
    size: int  # size column from the table
    url: str  # href exactly as it appears in the table (may be relative)
89
90
class AliasPin(NamedTuple):
    """Pin for an alias search path; aliases carry no pin fields."""
93
94
class GitPin(NamedTuple):
    """Pin for a git search path; its fields are merged into the config."""
    git_revision: str  # commit id the search path is pinned to
    release_name: str  # name produced by git_revision_name
98
99
class ChannelPin(NamedTuple):
    """Pin for a channel search path; its fields are merged into the config."""
    git_revision: str  # commit shown on the channel page
    release_name: str  # release name shown on the channel page
    tarball_url: str  # absolute URL of the channel's nixexprs.tar.xz
    tarball_sha256: str  # hex sha256 of that tarball, from the file table
105
106
# Any record a SearchPath.pin() implementation may return.
Pin = Union[AliasPin, GitPin, ChannelPin]
108
109
class SearchPath(types.SimpleNamespace, ABC):
    """Base class for one configured channel (one config-file section).

    read_search_path builds instances from a section's key/value pairs, so
    the section's config keys appear as attributes.
    """

    @abstractmethod
    def pin(self, v: Verification) -> Pin:
        """Resolve this search path to a concrete pin record."""
115
116
class AliasSearchPath(SearchPath):
    """A channel that is only another name for a different channel."""

    alias_of: str  # name of the config section this channel aliases

    def pin(self, v: Verification) -> AliasPin:
        """Return an empty AliasPin; aliases have nothing of their own to pin."""
        # An alias must not also configure its own git source.
        assert not hasattr(self, 'git_repo')
        return AliasPin()
123
124
# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """State shared by search paths backed by a git repository.

    Both subclasses verify a git revision from git_repo.  Channel search
    paths also fill in the channel-page fields below.
    """
    channel_html: bytes  # raw body of the channel page
    channel_url: str  # channel page URL from the config
    forwarded_url: str  # channel_url after following redirects
    git_ref: str  # branch fetched from git_repo (read as refs/heads/<git_ref>)
    git_repo: str  # repository location passed to `git fetch`
    old_git_revision: str  # previously pinned revision, when re-pinning
    table: Dict[str, ChannelTableEntry]  # channel file table, keyed by file name
135
136
class GitSearchPath(TarrableSearchPath):
    """A search path that tracks a branch of a git repository directly."""

    def pin(self, v: Verification) -> GitPin:
        """Fetch git_ref and pin this search path to its current commit.

        A previously pinned git_revision is kept as old_git_revision, so
        that git_fetch's ancestry check rejects a pin that moves backwards
        or sideways.  Raises VerificationError on any failed check.
        """
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        git_fetch(v, self)
        return GitPin(release_name=git_revision_name(v, self),
                      git_revision=self.git_revision)

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the store path of a tarball of the pinned revision.

        Raises Exception if the section lacks any key this needs (i.e. it
        has not been pinned), and VerificationError on failed checks.
        """
        # git_revision is read through self by ensure_git_rev_available and
        # git_get_tarball.  Without this check, a config missing it would
        # fail with an AttributeError instead of the message below.
        required = ('git_repo', 'release_name', 'git_revision')
        if any(key not in conf for key in required):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        ensure_git_rev_available(v, self)
        return git_get_tarball(v, self)
156
157
class ChannelSearchPath(TarrableSearchPath):
    """A search path that follows a published channel page (channel_url)."""

    def pin(self, v: Verification) -> ChannelPin:
        """Pin to the channel's current release after fully verifying it.

        A previously pinned git_revision is kept as old_git_revision for
        the ancestry check.  Raises VerificationError on any failed check.
        """
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        pin_channel(v, self)
        return ChannelPin(
            release_name=self.release_name,
            tarball_url=self.table['nixexprs.tar.xz'].absolute_url,
            tarball_sha256=self.table['nixexprs.tar.xz'].digest,
            git_revision=self.git_revision)

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Fetch the pinned nixexprs tarball and return its store path.

        Raises Exception if the section lacks any key this needs (i.e. it
        has not been pinned), and VerificationError if the digest checks fail.
        """
        # tarball_url / tarball_sha256 are read below.  Check them up front
        # so a partially pinned section gets the friendly message rather
        # than a bare KeyError.
        required = ('git_repo', 'release_name', 'tarball_url', 'tarball_sha256')
        if any(key not in conf for key in required):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(conf['tarball_sha256']))
183
184
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the contents of the file trees rooted at a and b.

    Considered paths are every file, and every symlink to a directory,
    found under either tree, minus anything under a top-level .git/
    directory.  They are compared by content with filecmp.cmpfiles.

    Returns cmpfiles' (match, mismatch, errors) lists of relative paths.
    errors holds paths that could not be compared, e.g. present in only
    one tree.  Raises OSError if either tree cannot be walked.
    """

    def rel_join(parent: str, child: str) -> str:
        # relpath reports the walk root as '.'; avoid producing './name'.
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def walk_error(err: OSError) -> None:
        raise err

    def tree_paths(root: str) -> Iterable[str]:
        found: List[str] = []
        for here, subdirs, leaves in os.walk(root, onerror=walk_error):
            rel = os.path.relpath(here, start=root)
            found += [rel_join(rel, leaf) for leaf in leaves]
            # os.walk does not descend into symlinked directories; compare
            # the link itself instead.
            found += [rel_join(rel, sub) for sub in subdirs
                      if os.path.islink(rel_join(here, sub))]
        return found

    def without_git(paths: Iterable[str]) -> Iterable[str]:
        return (p for p in paths if not p.startswith('.git/'))

    union = functools.reduce(
        operator.or_,
        (set(without_git(tree_paths(root))) for root in (a, b)))
    return filecmp.cmpfiles(a, b, union, shallow=False)
211
212
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel page and record where it redirected to.

    Sets channel.channel_html (raw body bytes) and channel.forwarded_url
    (the URL after redirects).  Then verifies the response status was 200
    and that channel_url did redirect elsewhere.  Raises VerificationError
    if either check fails.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # promptly instead of whenever the object is garbage-collected.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
220
221
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Extract the release name, git commit and file table from the page.

    Parses channel.channel_html as XML, then sets channel.release_name,
    channel.git_revision and channel.table (file name -> ChannelTableEntry).
    Raises VerificationError if the page is internally inconsistent.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_text(tag: str) -> str:
        # Text of the first child of the first <tag> element.
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue

    v.status('Extracting release name:')
    # The release name is the third word of both <title> and <h1>; they
    # must agree.
    title_name = first_text('title').split()[2]
    h1_name = first_text('h1').split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    # The <tt> holding the revision must follow this exact label text.
    v.result(commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first <tr> is skipped (presumably the header row).  Columns are:
    # link (text = file name), size, digest.
    for row in doc.getElementsByTagName('tr')[1:]:
        link = row.childNodes[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
253
254
def digest_string(s: bytes) -> Digest16:
    """Return the hex sha256 digest of the byte string s."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
257
258
def digest_file(filename: str) -> Digest16:
    """Return the hex sha256 digest of the file's contents.

    The file is read in 4 KiB chunks, so large files are never held
    entirely in memory.
    """
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
266
267
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 sha256 digest to hex via `nix to-base16`.

    Raises VerificationError if the nix command fails.
    """
    v.status('Converting digest to base16')
    argv = ['nix', 'to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    text = converted.stdout.decode()
    return Digest16(text.strip())
274
275
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to Nix base-32 via `nix to-base32`.

    Raises VerificationError if the nix command fails.
    """
    v.status('Converting digest to base32')
    argv = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    return Digest32(completed.stdout.decode().strip())
282
283
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url with nix-prefetch-url and return the path it prints.

    The expected digest is passed to nix-prefetch-url.  The result is then
    checked twice more: the digest nix-prefetch-url reports (converted to
    hex) and a fresh sha256 of the fetched file must both equal digest.
    Raises VerificationError on any failure.
    """
    v.status('Fetching %s' % url)
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # With --print-path the output is exactly "<digest>\n<path>\n".
    reported_digest, path, trailing = prefetch.stdout.decode().split('\n')
    assert trailing == ''
    reported_hex = to_Digest16(v, Digest32(reported_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_hex == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
300
301
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch the channel's git-revision and nixexprs.tar.xz files.

    For each, resolves its table URL against channel.forwarded_url (stored
    as absolute_url) and fetches it with fetch_with_nix_prefetch_url, which
    verifies the table digest (stored as file).  Then checks that the
    git-revision file matches the commit shown on the channel page.
    Raises VerificationError on any failed check.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file deterministically rather than leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
313
314
def git_cachedir(git_repo: str) -> str:
    """Return the local bare-repository cache directory for git_repo.

    The directory is $XDG_CACHE_HOME/pinch/git/<sha256 of git_repo>, so
    each distinct repository string gets its own directory.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
320
321
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the path of the file recording a generated tarball's store path.

    The name combines the sha256 of the repo, the git revision and the
    release name, so each (repo, revision, release name) gets its own entry.
    """
    repo_key = digest_string(channel.git_repo.encode())
    entry_name = '%s-%s-%s' % (repo_key,
                               channel.git_revision,
                               channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
330
331
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Verify where channel.git_revision sits in the cached repo's history.

    Checks that channel.git_revision is an ancestor of channel.git_ref.
    When channel has an old_git_revision from an earlier pin, also checks
    that the old revision is an ancestor of the new one, i.e. the pin only
    moves forward.  Raises VerificationError if either check fails.
    """
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              channel.git_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if hasattr(channel, 'old_git_revision'):
        # The command checks old_git_revision -> git_revision ancestry, so
        # the new rev must *descend* from the previous one.  The message
        # states the relationship that is actually checked.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  channel.old_git_revision,
                                  channel.git_revision])
        v.result(process.returncode == 0)
356
357
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local cache.

    Creates the bare cache repository on first use.  If channel already
    has a git_revision, verifies that the fetch made it available;
    otherwise sets channel.git_revision to the commit the fetched ref
    points at.  Finally runs verify_git_ancestry.  Raises
    VerificationError if any git command fails.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Ask git to resolve the ref instead of reading refs/heads/<ref>
        # directly.  Once gc / pack-refs has run, a ref the fetch did not
        # change may exist only in packed-refs, and no loose file exists.
        v.status('Resolving ref "%s"' % channel.git_ref)
        process = subprocess.run(
            ['git', '-C', cachedir, 'rev-parse', '--verify',
             'refs/heads/%s^{commit}' % channel.git_ref],
            stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
398
399
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure channel.git_revision is present in the local git cache.

    If the cache already has the revision, only its ancestry is verified.
    Otherwise git_fetch is run, which also verifies ancestry.  Raises
    VerificationError if `git cat-file -e` exits with anything other than
    0 or 1.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    probe = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
    answer = {0: 'yes', 1: 'no'}.get(probe.returncode)
    if answer is not None:
        v.status(answer)
    v.result(answer is not None)
    if probe.returncode == 0:
        verify_git_ancestry(v, channel)
    else:
        git_fetch(v, channel)
417
418
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify the unpacked channel tarball matches the git checkout.

    Requires at least one matching file and no differing files.  Every
    path that cannot be compared must be one of a fixed set of expected
    channel-only paths, and every one of those must show up.  Raises
    VerificationError otherwise.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Paths filecmp is expected to report as incomparable.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))
450
451
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the channel's fetched nixexprs.tar.xz into dest."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
462
463
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Export channel.git_revision from the cache repo into dest.

    Pipes `git archive` into `tar x`.  Raises VerificationError if either
    process exits non-zero.
    """
    v.status('Checking out corresponding git revision')
    archive = subprocess.Popen(
        ['git', '-C', git_cachedir(channel.git_repo), 'archive',
         channel.git_revision],
        stdout=subprocess.PIPE)
    untar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archive.stdout)
    # Drop our copy of the pipe so git sees SIGPIPE if tar exits early.
    if archive.stdout:
        archive.stdout.close()
    untar.wait()
    archive.wait()
    v.result(archive.returncode == 0 and untar.returncode == 0)
482
483
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Return a Nix store path to a .tar.xz of channel.git_revision.

    Every file in the archive sits under a top-level directory named
    channel.release_name.  The store path printed by `nix-store --add` is
    remembered in tarball_cache_file(channel).  While that recorded path
    still exists, it is returned without regenerating anything.  Raises
    VerificationError if git, xz or nix-store fails.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as cache:
            cached_tarball = cache.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this descriptor, so open
        # it in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # As in git_checkout: release our copy of the pipe's read end so
            # git gets SIGPIPE instead of blocking if xz exits early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
519
520
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Verify the metadata files inside the unpacked channel tarball.

    The tarball's .git-revision must equal channel.git_revision, and its
    .version-suffix must be a suffix of channel.release_name.  Raises
    VerificationError otherwise.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Read through `with` so the handles are closed deterministically.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
543
544
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Verify the channel tarball really is the git revision it claims.

    Unpacks the tarball, and a git export of channel.git_revision, into
    two scratch directories.  Then checks the tarball's metadata files and
    compares the two trees.  The scratch directories are removed afterwards.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            # Tarball side: unpack, then check .git-revision/.version-suffix.
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            # Git side: export the same revision for comparison.
            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(v, channel, channel_contents, git_contents)

            v.status('Removing temporary directories')
    v.ok()
560
561
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Resolve and fully verify a channel's current release.

    Runs, in order: fetch the channel page, parse it, fetch and verify its
    resources, make the git revision available locally, and compare the
    tarball with git.  Each step raises VerificationError on failure.
    """
    steps = (fetch,
             parse_channel,
             fetch_resources,
             ensure_git_rev_available,
             check_channel_contents)
    for step in steps:
        step(v, channel)
568
569
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name for channel.git_revision.

    The format is "<basename of git_repo>-<commit time>-<short hash>".
    The time is git's %ct Unix timestamp; the hash uses --abbrev=11.
    Raises VerificationError if `git log` fails or prints nothing.
    """
    v.status('Getting commit date')
    log = subprocess.run(
        ['git', '-C', git_cachedir(channel.git_repo), 'log', '-n1',
         '--format=%ct-%h', '--abbrev=11', '--no-show-signature',
         channel.git_revision],
        stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    stamp = log.stdout.decode().strip()
    return '%s-%s' % (os.path.basename(channel.git_repo), stamp)
585
586
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the SearchPath subclass matching a config section's keys.

    'alias_of' selects an alias, 'channel_url' a channel, and anything
    else is treated as a plain git search path.
    """
    if 'alias_of' in conf:
        return AliasSearchPath(**dict(conf.items()))
    fields = dict(conf.items())
    return (ChannelSearchPath(**fields) if 'channel_url' in conf
            else GitSearchPath(**fields))
593
594
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file at filename.

    The file handle is closed before returning (it was previously leaked).
    Raises OSError if the file cannot be opened and configparser.Error on
    malformed content.
    """
    config = configparser.ConfigParser()
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
599
600
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files into one mapping of section name -> section.

    Sections keep file order, then in-file order.  Raises Exception if the
    same section name appears more than once across the files.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for path in filenames:
        parsed = read_config(path)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
611
612
def pin(args: argparse.Namespace) -> None:
    """Pin channels in args.channels_file and rewrite the file.

    When args.channels is non-empty only those sections are pinned;
    otherwise every section is.  Each section's pin fields are merged into
    the in-memory config, and the whole file is written back at the end.
    NOTE: configparser does not preserve comments when writing, so any
    comments in the channels file are dropped.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = [section for section in config.sections()
                if not args.channels or section in args.channels]
    for section in selected:
        search_path = read_search_path(config[section])
        config[section].update(search_path.pin(v)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
626
627
def update(args: argparse.Namespace) -> None:
    """Fetch every configured channel and install them with nix-env.

    Non-alias sections are fetched first.  Each alias then reuses the nix
    expression of the section it names.  With args.dry_run, the nix-env
    command is printed instead of run.  Raises VerificationError if a
    fetch or the install fails.
    """
    v = Verification()
    config = read_config_files(args.channels_file)

    # section name -> nix expression template; the remaining %s (from the
    # escaped %%s) is filled with the channel name when building argv.
    exprs: Dict[str, str] = {}
    for section, conf in config.items():
        search_path = read_search_path(conf)
        if isinstance(search_path, AliasSearchPath):
            # Aliases are resolved below, once real channels are fetched.
            assert 'git_repo' not in conf
            continue
        tarball = search_path.fetch(v, section, conf)
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (conf['release_name'], tarball))

    for section, conf in config.items():
        if 'alias_of' in conf:
            exprs[section] = exprs[str(conf['alias_of'])]

    profile = '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression',
    ]
    command.extend(exprs[name] % name for name in sorted(exprs))

    if args.dry_run:
        print(' '.join(shlex.quote(arg) for arg in command))
        return
    v.status('Installing channels with nix-env')
    process = subprocess.run(command)
    v.result(process.returncode == 0)
662
663
def main() -> None:
    """Parse the command line and dispatch to pin() or update()."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # pinch pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    # pinch update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
678
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()