]> git.scottworley.com Git - pinch/blob - pinch.py
Polymorphic pin()
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 NewType,
26 Tuple,
27 )
28
# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    """The XDG base-directory locations this tool uses."""
    XDG_CACHE_HOME: str


def _xdg_cache_home() -> str:
    """Return $XDG_CACHE_HOME, or ~/.cache if it is unset, empty or relative.

    The XDG Base Directory Specification says an unset *or empty* variable
    means "use the default", and that relative paths must be ignored;
    ``os.getenv``'s default argument alone only handles the unset case.
    """
    value = os.getenv('XDG_CACHE_HOME', '')
    if os.path.isabs(value):
        return value
    return os.path.expanduser('~/.cache')


xdg = XDG(XDG_CACHE_HOME=_xdg_cache_home())
40
41
class VerificationError(Exception):
    """Raised by Verification.result when a checked condition is false."""
44
45
class Verification:
    """Prints a running checklist of verification steps to stderr.

    Callers describe a step with ``status`` and report its outcome with
    ``result`` (or the ``ok``/``check`` shorthands).  A false outcome prints
    a red FAIL marker and raises VerificationError.
    """

    def __init__(self) -> None:
        # Characters already printed on the current stderr line; used to
        # right-align the OK/FAIL marker.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space, without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, not terminal columns, so wide or
        # combining characters would throw the alignment off.
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with a right-aligned OK or FAIL marker.

        Raises:
            VerificationError: if *r* is false.
        """
        outcomes = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = outcomes[r]
        columns = shutil.get_terminal_size().columns or 80
        # Pad so the marker ends at the right edge; the modulo keeps the
        # padding non-negative once the line is wider than the terminal.
        used = self.line_length + len(message)
        padding = (columns - used) % columns
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Shorthand for ``status(s)`` followed by ``result(r)``."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Shorthand for ``result(True)``."""
        self.result(True)
75
76
# Hex ("base16") digest string, as produced by hashlib's hexdigest() in
# digest_string/digest_file.
Digest16 = NewType('Digest16', str)
# Digest in Nix's base32 encoding, as printed by nix-prefetch-url.
Digest32 = NewType('Digest32', str)
79
80
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's release page.

    parse_channel fills in url, digest and size; fetch_resources later adds
    absolute_url and file.
    """
    absolute_url: str  # url resolved against the channel's forwarded URL
    digest: Digest16  # expected SHA-256 of the file, from the table
    file: str  # local path printed by nix-prefetch-url for the download
    size: int  # size column of the table
    url: str  # href from the table; may be relative
87
88
class SearchPath(types.SimpleNamespace, ABC):
    """Base class for one configured channel (one config-file section).

    Instances are constructed from a section's key/value pairs, so their
    attributes mirror the config keys.
    """
    release_name: str

    @abstractmethod
    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Resolve this channel to a concrete release and record it in *conf*."""
95
96
class AliasSearchPath(SearchPath):
    """A channel that is another name for the channel named by ``alias_of``.

    It has nothing of its own to pin; ``update`` reuses the aliased
    channel's expression for it.
    """
    alias_of: str

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Nothing to pin; only check that the alias carries no git settings.

        Raises:
            Exception: the alias section also sets ``git_repo``.
        """
        # An explicit raise rather than ``assert``: this validates user
        # configuration and must still run under ``python -O``.
        if hasattr(self, 'git_repo'):
            raise Exception(
                'Alias channel (alias_of = "%s") must not set git_repo' %
                self.alias_of)
102
103
# (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
# which is fixed in pylint 2.5.)
class TarrableSearchPath(SearchPath, ABC):  # pylint: disable=abstract-method
    """A search path whose contents end up as a tarball in the Nix store.

    Subclasses differ in how they ``pin``; ``fetch`` is shared.
    """
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str
    table: Dict[str, ChannelTableEntry]

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the Nix store path of this pinned channel's tarball.

        Channels with a ``channel_url`` are downloaded from the recorded
        ``tarball_url`` and checked against ``tarball_sha256``; the others
        are built from the pinned git revision.

        Raises:
            Exception: *conf* lacks ``git_repo`` or ``release_name``, i.e.
                the channel has not been pinned yet.
        """
        is_pinned = 'git_repo' in conf and 'release_name' in conf
        if not is_pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        if 'channel_url' not in conf:
            ensure_git_rev_available(v, self)
            return git_get_tarball(v, self)

        expected_digest = Digest16(conf['tarball_sha256'])
        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], expected_digest)
130
131
class GitSearchPath(TarrableSearchPath):
    """A channel pinned directly to the head of a git ref."""

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Fetch ``git_ref`` and record its head revision in *conf*.

        Any previously pinned revision is kept as ``old_git_revision`` and
        removed as ``git_revision``, so git_fetch resolves the ref afresh.
        """
        try:
            self.old_git_revision = self.git_revision
        except AttributeError:
            pass  # never pinned before
        else:
            del self.git_revision

        git_fetch(v, self)
        release_name = git_revision_name(v, self)
        conf['release_name'] = release_name
        conf['git_revision'] = self.git_revision
141
142
class ChannelSearchPath(TarrableSearchPath):
    """A channel pinned from a Nix channel release page (``channel_url``)."""

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin to the channel's current release and record it in *conf*.

        Records the release name, the nixexprs tarball URL and digest, and
        the git revision.  A previously pinned revision is remembered as
        ``old_git_revision``.
        """
        if hasattr(self, 'git_revision'):
            previous = self.git_revision
            del self.git_revision
            self.old_git_revision = previous

        pin_channel(v, self)
        conf['release_name'] = self.release_name
        nixexprs = self.table['nixexprs.tar.xz']
        conf['tarball_url'] = nixexprs.absolute_url
        conf['tarball_sha256'] = nixexprs.digest
        conf['git_revision'] = self.git_revision
154
155
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees under directories *a* and *b* by content.

    Every file path (relative to its root) found under either tree, except
    those inside a top-level ``.git/`` directory, is compared with
    ``filecmp.cmpfiles(shallow=False)``.  Symlinks to directories are
    treated as files rather than descended into.

    Returns:
        ``(match, mismatch, errors)`` as from ``filecmp.cmpfiles``; *errors*
        holds paths that could not be compared, e.g. ones missing on one side.

    Raises:
        OSError: a directory could not be listed while walking either tree.
    """

    def reraise(error: OSError) -> None:
        raise error

    def relative_join(prefix: str, name: str) -> str:
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def walk_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(dirpath, start=root)
            for filename in filenames:
                found.append(relative_join(relative, filename))
            # os.walk lists symlinks to directories among the subdirectories
            # (without following them); count those as files.
            found.extend(
                relative_join(relative, sub) for sub in subdirs
                if os.path.islink(relative_join(dirpath, sub)))
        return found

    per_tree = [
        set(path for path in walk_files(root) if not path.startswith('.git/'))
        for root in (a, b)]
    candidates = functools.reduce(operator.or_, per_tree)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
182
183
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel page and record where its URL redirected to.

    Sets ``channel.channel_html`` and ``channel.forwarded_url``, then checks
    that the request succeeded and that it was in fact forwarded.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # even when a check below raises VerificationError.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        v.result(request.status == 200)  # type: ignore # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
191
192
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Parse ``channel.channel_html`` (the page fetched by ``fetch``).

    Sets ``channel.release_name``, ``channel.git_revision`` and
    ``channel.table`` (file name -> ChannelTableEntry), reporting each step
    on *v*.  Any failed check raises VerificationError via *v*.
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    v.status('Extracting release name:')
    # The release name is the third whitespace-separated word of both the
    # <title> and the first <h1>; the two must agree.
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    # The commit hash is the text of the first <tt> element ...
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    # ... which must be immediately preceded by the text "Git commit ".
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # Skip the first <tr> (presumably the header row).  Each remaining row is
    # read positionally: cell 0 wraps a link (href + file name text), cell 1
    # holds the size, and cell 2 wraps the digest text.
    # NOTE(review): this relies on there being no whitespace text nodes
    # between cells in the served page.
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
224
225
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of the byte string *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
228
229
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of the contents of *filename*.

    The file is read in 4096-byte chunks, so it never has to fit in memory.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            chunk = f.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
237
238
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base32 SHA-256 digest to hex with ``nix to-base16``.

    A failing nix command is reported on *v*, which raises VerificationError.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    output = converted.stdout.decode()
    return Digest16(output.strip())
245
246
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base32 with ``nix to-base32``.

    A failing nix command is reported on *v*, which raises VerificationError.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    output = converted.stdout.decode()
    return Digest32(output.strip())
253
254
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with nix-prefetch-url and return the path it prints.

    *digest* is passed to nix-prefetch-url as the expected hash.  Both the
    digest nix-prefetch-url reports and a fresh SHA-256 of the resulting
    file are then checked against *digest* via *v*.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Expected output: "<base32 digest>\n<path>\n".
    output_lines = process.stdout.decode().split('\n')
    prefetch_digest, path, trailing = output_lines
    assert trailing == ''
    reported_ok = to_Digest16(v, Digest32(prefetch_digest)) == digest
    v.check("Verifying nix-prefetch-url's digest", reported_ok)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
271
272
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Fills in ``absolute_url`` and ``file`` on their table entries, then
    checks that the downloaded git-revision file names the same commit as
    the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        # Table hrefs may be relative to the (post-redirect) channel URL.
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file deterministically instead of leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        v.result(revision_file.read(999) == channel.git_revision)
284
285
def git_cachedir(git_repo: str) -> str:
    """Return the local bare-repository cache directory for *git_repo*.

    The directory is named after the SHA-256 of the repository URL, so
    each distinct URL gets its own cache under ``pinch/git`` in the XDG
    cache directory.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
291
292
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the path of the file recording the store tarball for *channel*.

    The name combines a hash of the repo URL with the git revision and the
    release name, so any change to those yields a different cache entry.
    """
    repo_key = digest_string(channel.git_repo.encode())
    entry_name = '%s-%s-%s' % (
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
301
302
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Check the pinned revision's history in the local git cache.

    Verifies that ``channel.git_revision`` is an ancestor of (or equal to)
    ``channel.git_ref`` and, if an ``old_git_revision`` is known, that the
    old pin is an ancestor of the new one.  That means a re-pin never moves
    backwards or sideways.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The check is "old rev is an ancestor of new rev"; the message used
        # to state this backwards.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
327
328
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch ``channel.git_ref`` from ``channel.git_repo`` into the local cache.

    If *channel* already has a ``git_revision``, verify that the fetch
    brought that object in.  Otherwise set ``git_revision`` to the fetched
    head of ``git_ref``.  Either way, finish with verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Let git resolve the ref rather than reading refs/heads/<ref>
        # directly: once refs are packed (e.g. by `git gc`), the loose ref
        # file no longer exists.
        v.status('Resolving ref "%s"' % channel.git_ref)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'rev-parse',
                                  '--verify',
                                  'refs/heads/%s' % channel.git_ref],
                                 stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
369
370
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure ``channel.git_revision`` is present in the local git cache.

    If the cache already contains the revision, only its ancestry is
    verified.  Otherwise (or if there is no cache yet) the ref is fetched
    with git_fetch.  An unexpected ``git cat-file`` exit status fails via *v*.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    process = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
    # cat-file -e exits 0 if the object exists and 1 if it does not.
    answer = {0: 'yes', 1: 'no'}.get(process.returncode)
    if answer is not None:
        v.status(answer)
    v.result(answer is not None)

    if process.returncode == 0:
        verify_git_ancestry(v, channel)
        return
    git_fetch(v, channel)
388
389
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Check the unpacked channel tarball against the git checkout.

    The tarball's ``<release_name>/`` directory is compared file by file with
    *git_contents*.  The check requires some matches, no differing files, and
    incomparable files exactly equal to a fixed list of expected names.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Names expected to be incomparable, presumably because they exist only
    # in the channel tarball (.git-revision and .version-suffix are read from
    # it by check_channel_metadata).
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    v.check('%d unexpected incomparable files' % len(errors), not errors)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
421
422
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the channel's downloaded nixexprs.tar.xz into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
433
434
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Extract the tree of ``channel.git_revision`` into *dest*.

    ``git archive`` output from the cache repository is piped into
    ``tar x``.  Fails via *v* if either process exits non-zero.
    """
    v.status('Checking out corresponding git revision')
    archive = subprocess.Popen(
        ['git', '-C', git_cachedir(channel.git_repo),
         'archive', channel.git_revision],
        stdout=subprocess.PIPE)
    untar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archive.stdout)
    # Drop our copy of the pipe's read end so git sees SIGPIPE if tar exits
    # early.
    if archive.stdout:
        archive.stdout.close()
    untar.wait()
    archive.wait()
    v.result(archive.returncode == 0 and untar.returncode == 0)
453
454
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Return a Nix store path holding an .tar.xz of ``channel.git_revision``.

    A previously generated tarball is reused if the cache file names a path
    that still exists.  Otherwise the tarball is built with ``git archive |
    xz``, added with ``nix-store --add``, and its store path is recorded in
    the cache file.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # Binary mode: xz writes compressed bytes straight to this file.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop our copy of the pipe's read end (as git_checkout does) so
            # git gets SIGPIPE instead of blocking forever if xz exits early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball
490
491
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    The tarball's ``.git-revision`` must equal ``channel.git_revision``, and
    its ``.version-suffix`` must be a suffix of ``channel.release_name``.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Use `with` so the file handles are closed, not left to the GC.
    with open(os.path.join(release_dir, '.git-revision')) as f:
        v.result(f.read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
514
515
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Verify the downloaded channel tarball against the pinned git revision.

    The tarball is unpacked and the git revision checked out into two
    temporary directories.  The tarball's metadata files are checked and the
    two trees compared.  The final OK is printed after both temporary
    directories have been removed.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            v.status('Removing temporary directories')
    v.ok()
531
532
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Pin *channel* to its current release and verify it against git.

    The steps run in order: fetch the channel page, parse it, download its
    resources, make sure the git revision is cached, and compare the
    tarball's contents with git.
    """
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
539
540
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name for a git-pinned channel.

    The name has the form ``<repo basename>-<commit time>-<hash>``.  The
    commit time is the committer timestamp (``%ct``) and the hash is
    abbreviated to 11 characters.
    """
    v.status('Getting commit date')
    command = ['git',
               '-C',
               git_cachedir(channel.git_repo),
               'log',
               '-n1',
               '--format=%ct-%h',
               '--abbrev=11',
               '--no-show-signature',
               channel.git_revision]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    time_and_hash = process.stdout.decode().strip()
    return '%s-%s' % (os.path.basename(channel.git_repo), time_and_hash)
556
557
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the SearchPath subclass matching a config section.

    A section with ``alias_of`` is an alias.  Otherwise a section with
    ``channel_url`` is a channel, and anything else is a plain git channel.
    Every key/value in the section becomes an attribute.
    """
    fields = dict(conf.items())
    if 'alias_of' in conf:
        return AliasSearchPath(**fields)
    if 'channel_url' not in conf:
        return GitSearchPath(**fields)
    return ChannelSearchPath(**fields)
564
565
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    Raises:
        OSError: the file cannot be opened.
        configparser.Error: the file is malformed (e.g. a duplicate section).
    """
    config = configparser.ConfigParser()
    # Close the file once parsed instead of leaking the handle.
    with open(filename) as f:
        config.read_file(f, filename)
    return config
570
571
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files and merge their sections by name.

    Raises:
        Exception: the same section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for config in map(read_config, filenames):
        for name in config.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = config[name]
    return merged_config
582
583
def pin(args: argparse.Namespace) -> None:
    """``pinch pin``: pin channels and write the pins back to the file.

    Only the channels named on the command line are pinned; if none are
    named, every section is.  The file is rewritten only if every selected
    channel pinned without raising.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = args.channels
    for name in config.sections():
        if selected and name not in selected:
            continue
        section = config[name]
        read_search_path(section).pin(v, section)

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
597
598
def update(args: argparse.Namespace) -> None:
    """``pinch update``: install all pinned channels into the channels profile.

    Each non-alias channel is fetched to a Nix store tarball, and each alias
    reuses the expression of the channel it names.  With ``--dry-run`` the
    ``nix-env`` command is printed, shell-quoted, instead of being run.
    """
    v = Verification()
    config = read_config_files(args.channels_file)
    # "%%s" survives this first formatting as a "%s" placeholder, which is
    # filled with the channel's own name when the command is assembled.
    template = 'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }'
    exprs: Dict[str, str] = {}
    for name, section in config.items():
        search_path = read_search_path(section)
        if isinstance(search_path, AliasSearchPath):
            assert 'git_repo' not in section
            continue
        tarball = search_path.fetch(v, name, section)
        exprs[name] = template % (section['release_name'], tarball)

    for name, section in config.items():
        if 'alias_of' in section:
            exprs[name] = exprs[str(section['alias_of'])]

    profile = '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression',
    ]
    command += [exprs[name] % name for name in sorted(exprs)]

    if not args.dry_run:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
    else:
        print(' '.join(map(shlex.quote, command)))
633
634
def main() -> None:
    """Command-line entry point: run the ``pin`` or ``update`` subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
648
649
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()