]> git.scottworley.com Git - pinch/blob - pinch.py
58f7a3290cc507760303858cd90ba351f118d0e0
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 NewType,
26 Tuple,
27 )
28
29 # Use xdg module when it's less painful to have as a dependency
30
31
class XDG(types.SimpleNamespace):
    """Minimal stand-in for the `xdg` package: only the cache directory."""
    XDG_CACHE_HOME: str


# The XDG Base Directory spec treats an unset *or empty* $XDG_CACHE_HOME as
# "use the default".  os.getenv's default argument only covers the unset case;
# an empty value would otherwise put the cache relative to the current
# working directory.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME') or os.path.expanduser('~/.cache'))
40
41
class VerificationError(Exception):
    """Raised by Verification.result when a checked condition is false."""
44
45
class Verification:
    """Progress reporter: status text on stderr, verdicts right-aligned.

    Each check prints one or more status fragments followed by a coloured
    "OK" or "FAIL"; a failed check raises VerificationError.
    """

    def __init__(self) -> None:
        # Characters already printed on the current status line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* plus a trailing space to stderr without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, so wide/combining characters may misalign.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape for code *c*, followed by a reset."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK (92) / FAIL (91); raise on FAIL."""
        verdicts = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = verdicts[r]
        width = shutil.get_terminal_size().columns or 80
        # The modulo keeps the padding non-negative when the status text is
        # already wider than the terminal.
        padding = (width - (self.line_length + len(message))) % width
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print status *s* and immediately report verdict *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current status line."""
        self.result(True)
75
76
# Hex (base-16) sha256 digest, as produced by hashlib's hexdigest() and as
# listed in a channel page's file table.
Digest16 = NewType('Digest16', str)
# Nix base-32 sha256 digest, as printed by nix-prefetch-url / `nix to-base32`.
Digest32 = NewType('Digest32', str)
79
80
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page."""
    absolute_url: str  # `url` resolved against the channel's forwarded URL
    digest: Digest16  # sha256 of the file, as listed in the table
    file: str  # path printed by `nix-prefetch-url --print-path` after fetching
    size: int  # value of the table's size column
    url: str  # href from the table; may be relative
87
88
class SearchPath(types.SimpleNamespace, ABC):
    """Abstract base for one section of a channels config file.

    Instances are constructed from the section's key/value pairs, so each
    config key becomes an attribute.
    """
    release_name: str

    @abstractmethod
    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Resolve this search path to a fixed revision and record it in *conf*."""
95
96
class AliasSearchPath(SearchPath):
    """A config section that reuses another section's channel via `alias_of`."""
    alias_of: str  # name of the config section this one mirrors

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        # An alias has nothing of its own to pin; only sanity-check that the
        # section does not also specify a git_repo (update() checks the same).
        assert not hasattr(self, 'git_repo')
102
103
# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """Base for search paths tied to a git repository and revision.

    Some attributes come from the config section; the rest are filled in by
    the fetch/parse/git helper functions while pinning.
    """
    channel_html: bytes  # raw body of channel_url (set by fetch)
    channel_url: str  # channel page URL from the config section
    forwarded_url: str  # final URL after redirects (set by fetch)
    git_ref: str  # ref fetched from git_repo
    git_repo: str  # git repository URL; also keys the local cache dir
    git_revision: str  # pinned commit
    old_git_revision: str  # previously pinned commit, for ancestry checks
    table: Dict[str, ChannelTableEntry]  # channel page file table, by file name
115
116
class GitSearchPath(TarrableSearchPath):
    """A search path pinned directly to a commit of a git repository."""

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Fetch git_ref, pin its head commit and record the pin in *conf*."""
        if hasattr(self, 'git_revision'):
            # Keep the previous pin for the ancestry check, and forget it so
            # git_fetch resolves the ref's current head.
            self.old_git_revision = self.git_revision
            delattr(self, 'git_revision')

        git_fetch(v, self)
        release = git_revision_name(v, self)
        conf['release_name'] = release
        conf['git_revision'] = self.git_revision

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return a Nix store path to a tarball of the pinned revision."""
        pinned = 'git_repo' in conf and 'release_name' in conf
        if not pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        ensure_git_rev_available(v, self)
        return git_get_tarball(v, self)
136
137
class ChannelSearchPath(TarrableSearchPath):
    """A search path pinned to a Nix channel release, cross-checked with git."""

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin the channel's current release and record tarball URL/digest."""
        if hasattr(self, 'git_revision'):
            # Keep the previous pin for the ancestry check.
            self.old_git_revision = self.git_revision
            delattr(self, 'git_revision')

        pin_channel(v, self)
        conf['release_name'] = self.release_name
        tarball = self.table['nixexprs.tar.xz']
        conf['tarball_url'] = tarball.absolute_url
        conf['tarball_sha256'] = tarball.digest
        conf['git_revision'] = self.git_revision

    # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
    # pylint: disable=no-self-use
    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the store path of the pinned, digest-verified tarball."""
        if not ('git_repo' in conf and 'release_name' in conf):
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        url = conf['tarball_url']
        expected = Digest16(conf['tarball_sha256'])
        return fetch_with_nix_prefetch_url(v, url, expected)
162
163
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Deep-compare every file found under either of two directory trees.

    Collects relative paths of regular files (plus symlinks to directories,
    which os.walk does not descend into) under *a* and *b*, drops anything
    under a top-level `.git/`, and returns filecmp.cmpfiles'
    (match, mismatch, errors) lists for the union.
    """

    def reraise(err: OSError) -> None:
        raise err

    def relative_join(prefix: str, name: str) -> str:
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def walk_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=reraise):
            rel = os.path.relpath(dirpath, start=root)
            for name in filenames:
                found.append(relative_join(rel, name))
            # Symlinked directories show up in `subdirs`; compare the link
            # itself as a file.
            found.extend(relative_join(rel, sub) for sub in subdirs
                         if os.path.islink(relative_join(dirpath, sub)))
        return found

    left = {f for f in walk_files(a) if not f.startswith('.git/')}
    right = {f for f in walk_files(b) if not f.startswith('.git/')}
    return filecmp.cmpfiles(a, b, left | right, shallow=False)
190
191
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download channel.channel_url and record its body and redirect target.

    Sets channel.channel_html and channel.forwarded_url, then requires that
    the request succeeded with HTTP 200 and that it was redirected (the
    forwarded URL differs from the configured one).
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # instead of being left for the garbage collector.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
199
200
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Extract release name, git commit and file table from the channel page.

    The release name is the third whitespace-separated word of both <title>
    and <h1>, which must agree. The commit is the text of the first <tt>,
    whose preceding text must be "Git commit ". Every <tr> after the first
    becomes a ChannelTableEntry keyed by the file name in its first cell.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_text(tag: str) -> str:
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue

    v.status('Extracting release name:')
    from_title = first_text('title').split()[2]
    from_h1 = first_text('h1').split()[2]
    v.status(from_title)
    v.result(from_title == from_h1)
    channel.release_name = from_title

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    rows = doc.getElementsByTagName('tr')[1:]  # skip the header row
    for row in rows:
        cells = row.childNodes
        anchor = cells[0].firstChild
        name = anchor.firstChild.nodeValue
        url = anchor.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
232
233
def digest_string(s: bytes) -> Digest16:
    """Return the hex sha256 digest of the byte string *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
236
237
def digest_file(filename: str) -> Digest16:
    """Return the hex sha256 digest of *filename*'s contents.

    The file is read in 4096-byte chunks so large files are not loaded
    into memory at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
245
246
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 sha256 digest to hex with `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
253
254
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to Nix base-32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
261
262
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch *url* with nix-prefetch-url and return the printed path.

    The expected hex *digest* is checked twice: against nix-prefetch-url's
    own reported digest (converted to hex), and by hashing the file.
    """
    v.status('Fetching %s' % url)
    completed = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    # Output is "<base-32 digest>\n<path>\n".
    prefetch_digest, path, empty = completed.stdout.decode().split('\n')
    assert empty == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
279
280
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each table entry's URL is resolved against the forwarded channel URL and
    fetched with digest verification; afterwards the downloaded git-revision
    file must match the commit shown on the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file deterministically instead of leaking the handle.
    with open(channel.table['git-revision'].file) as f:
        table_revision = f.read(999)
    v.result(table_revision == channel.git_revision)
292
293
def git_cachedir(git_repo: str) -> str:
    """Return the local cache directory for *git_repo*.

    The directory name is the sha256 of the repository URL, under
    $XDG_CACHE_HOME/pinch/git.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
299
300
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the file that remembers the store path of a generated tarball.

    The name combines the sha256 of the repository URL, the git revision
    and the release name.
    """
    repo_key = digest_string(channel.git_repo.encode())
    name = '{}-{}-{}'.format(
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', name)
309
310
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Check the pinned revision's ancestry in the local git cache.

    Requires git_revision to be an ancestor of git_ref and, when a previous
    pin exists (old_git_revision), requires the previous pin to be an
    ancestor of git_revision, i.e. the new pin only moves forward.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(older: str, newer: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 when A is an ancestor of B.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  older,
                                  newer])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The check is "previous rev is an ancestor of rev"; the message
        # previously stated the reverse relationship.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
335
336
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local bare cache.

    If the channel already has a git_revision, verify that the fetch brought
    in that commit; otherwise set git_revision to the fetched ref's head.
    Finishes with verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Let git resolve the ref instead of reading refs/heads/<ref> as a
        # loose file: once `git gc` / `git pack-refs` runs, the ref lives in
        # packed-refs and the loose file no longer exists.
        v.status('Resolving fetched ref')
        process = subprocess.run(
            ['git', '-C', cachedir, 'rev-parse', '--verify',
             'refs/heads/%s' % channel.git_ref],
            stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
377
378
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure channel.git_revision is present in the local git cache.

    If the cache already contains the commit, only re-check its ancestry;
    otherwise (or if there is no cache yet) run git_fetch.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    code = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision]
    ).returncode
    answer = {0: 'yes', 1: 'no'}.get(code)
    if answer is not None:
        v.status(answer)
    # Any exit code other than 0 or 1 is treated as a failure.
    v.result(code in (0, 1))
    if code == 0:
        verify_git_ancestry(v, channel)
        return
    git_fetch(v, channel)
396
397
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the unpacked channel tarball matches the git checkout.

    Requires at least one matching file and no differing files. A fixed set
    of names is expected to be incomparable; any other incomparable file is
    a failure, and every expected one must actually be incomparable.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    v.check(
        '%d unexpected incomparable files' % len(errors),
        not errors)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
429
430
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the channel's fetched nixexprs.tar.xz into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
441
442
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Extract the tree of channel.git_revision from the git cache into *dest*.

    Pipes `git archive` into `tar x`; succeeds only if both processes exit 0.
    """
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Drop the parent's copy of the pipe so tar holds the only read end; if
    # tar exits early, git then receives SIGPIPE instead of blocking.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
461
462
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Return a Nix store path to a .tar.xz of channel.git_revision.

    The tarball is built with `git archive | xz` (entries prefixed with
    `<release_name>/`), added with `nix-store --add`, and its store path is
    remembered in tarball_cache_file(channel). A remembered path is reused
    only while it still exists.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes raw bytes to this file descriptor; open it in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close the parent's copy of the pipe so xz holds the only read
            # end. Otherwise, if xz dies, git never gets SIGPIPE and can block
            # on a full pipe, hanging git.wait().
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
498
499
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    `.git-revision` must equal channel.git_revision, and `.version-suffix`
    must be a suffix of channel.release_name.
    """

    def read_metadata(name: str) -> str:
        # Reads at most 999 characters; the file handle is closed promptly.
        path = os.path.join(channel_contents, channel.release_name, name)
        with open(path) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
522
523
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Verify the channel tarball's contents against a checkout of git_revision.

    Both trees live in temporary directories, which are removed before the
    final OK is printed.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            v.status('Removing temporary directories')
    v.ok()
539
540
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Run every pinning and verification step for a channel, in order."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
547
548
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Name the pinned revision as '<repo basename>-<%ct>-<%h>'.

    %ct is the committer timestamp and %h the hash abbreviated to 11
    characters, as formatted by `git log`.
    """
    v.status('Getting commit date')
    log = subprocess.run(
        ['git', '-C', git_cachedir(channel.git_repo), 'log', '-n1',
         '--format=%ct-%h', '--abbrev=11', '--no-show-signature',
         channel.git_revision],
        stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    stamp = log.stdout.decode().strip()
    return '-'.join((os.path.basename(channel.git_repo), stamp))
564
565
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the SearchPath subclass matching the keys present in *conf*.

    `alias_of` selects an alias, `channel_url` a channel, and anything else
    is a plain git search path. Every key becomes an attribute.
    """
    fields = dict(conf.items())
    if 'alias_of' in conf:
        return AliasSearchPath(**fields)
    if 'channel_url' in conf:
        return ChannelSearchPath(**fields)
    return GitSearchPath(**fields)
572
573
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style config file *filename*.

    The file is opened in a `with` block so its handle is closed promptly
    instead of being left for the garbage collector.
    """
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config
578
579
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several config files into one mapping.

    Raises Exception if the same section name appears more than once.
    """
    merged: Dict[str, configparser.SectionProxy] = {}
    for config in map(read_config, filenames):
        for name in config.sections():
            if name in merged:
                raise Exception('Duplicate channel "%s"' % name)
            merged[name] = config[name]
    return merged
590
591
def pin(args: argparse.Namespace) -> None:
    """Pin the selected channels (all, if none are named) and rewrite the file."""
    v = Verification()
    config = read_config(args.channels_file)
    selected = [name for name in config.sections()
                if not args.channels or name in args.channels]
    for section in selected:
        read_search_path(config[section]).pin(v, config[section])

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
605
606
def update(args: argparse.Namespace) -> None:
    """Fetch every pinned channel and install them into the channels profile.

    With --dry-run, print the nix-env command (shell-quoted) instead of
    running it.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section, conf in config.items():
        sp = read_search_path(conf)
        if isinstance(sp, AliasSearchPath):
            assert 'git_repo' not in conf
            continue
        tarball = sp.fetch(v, section, conf)
        # '%%s' survives this formatting as '%s' and is filled in with the
        # channel name when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (conf['release_name'], tarball))

    for section, conf in config.items():
        if 'alias_of' in conf:
            exprs[section] = exprs[str(conf['alias_of'])]

    profile = '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()
    installs = [exprs[name] % name for name in sorted(exprs)]
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + installs

    if args.dry_run:
        print(' '.join(shlex.quote(word) for word in command))
        return
    v.status('Installing channels with nix-env')
    v.result(subprocess.run(command).returncode == 0)
641
642
def main() -> None:
    """Parse the command line and dispatch to the `pin` or `update` command."""
    parser = argparse.ArgumentParser(prog='pinch')
    commands = parser.add_subparsers(dest='mode', required=True)

    pin_parser = commands.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = commands.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
656
657
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()