]> git.scottworley.com Git - pinch/blob - pinch.py
4c2341f6f03d772525c5e6531fda657ff5458fd9
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 Mapping,
26 NamedTuple,
27 NewType,
28 Tuple,
29 Type,
30 Union,
31 )
32
33 # Use xdg module when it's less painful to have as a dependency
34
35
36 class XDG(types.SimpleNamespace):
37 XDG_CACHE_HOME: str
38
39
40 xdg = XDG(
41 XDG_CACHE_HOME=os.getenv(
42 'XDG_CACHE_HOME',
43 os.path.expanduser('~/.cache')))
44
45
46 class VerificationError(Exception):
47 pass
48
49
50 class Verification:
51
52 def __init__(self) -> None:
53 self.line_length = 0
54
55 def status(self, s: str) -> None:
56 print(s, end=' ', file=sys.stderr, flush=True)
57 self.line_length += 1 + len(s) # Unicode??
58
59 @staticmethod
60 def _color(s: str, c: int) -> str:
61 return '\033[%2dm%s\033[00m' % (c, s)
62
63 def result(self, r: bool) -> None:
64 message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
65 length = len(message)
66 cols = shutil.get_terminal_size().columns or 80
67 pad = (cols - (self.line_length + length)) % cols
68 print(' ' * pad + self._color(message, color), file=sys.stderr)
69 self.line_length = 0
70 if not r:
71 raise VerificationError()
72
73 def check(self, s: str, r: bool) -> None:
74 self.status(s)
75 self.result(r)
76
77 def ok(self) -> None:
78 self.result(True)
79
80
81 Digest16 = NewType('Digest16', str)
82 Digest32 = NewType('Digest32', str)
83
84
85 class ChannelTableEntry(types.SimpleNamespace):
86 absolute_url: str
87 digest: Digest16
88 file: str
89 size: int
90 url: str
91
92
93 class AliasPin(NamedTuple):
94 pass
95
96
97 class GitPin(NamedTuple):
98 git_revision: str
99 release_name: str
100
101
102 class ChannelPin(NamedTuple):
103 git_revision: str
104 release_name: str
105 tarball_url: str
106 tarball_sha256: str
107
108
109 Pin = Union[AliasPin, GitPin, ChannelPin]
110
111
112 class SearchPath(types.SimpleNamespace, ABC):
113
114 @abstractmethod
115 def pin(self, v: Verification) -> Pin:
116 pass
117
118
119 class AliasSearchPath(SearchPath):
120 alias_of: str
121
122 def pin(self, v: Verification) -> AliasPin:
123 assert not hasattr(self, 'git_repo')
124 return AliasPin()
125
126
127 # (This lint-disable is for pylint bug https://github.com/PyCQA/pylint/issues/179
128 # which is fixed in pylint 2.5.)
129 class TarrableSearchPath(SearchPath, ABC): # pylint: disable=abstract-method
130 channel_html: bytes
131 channel_url: str
132 forwarded_url: str
133 git_ref: str
134 git_repo: str
135 old_git_revision: str
136 table: Dict[str, ChannelTableEntry]
137
138
139 class GitSearchPath(TarrableSearchPath):
140 def pin(self, v: Verification) -> GitPin:
141 if hasattr(self, 'git_revision'):
142 self.old_git_revision = self.git_revision
143 del self.git_revision
144
145 git_fetch(v, self)
146 return GitPin(release_name=git_revision_name(v, self),
147 git_revision=self.git_revision)
148
149 def fetch(self, v: Verification, section: str,
150 conf: configparser.SectionProxy) -> str:
151 if 'git_revision' not in conf or 'release_name' not in conf:
152 raise Exception(
153 'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
154 section)
155
156 ensure_git_rev_available(v, self)
157 return git_get_tarball(v, self)
158
159
160 class ChannelSearchPath(TarrableSearchPath):
161 def pin(self, v: Verification) -> ChannelPin:
162 if hasattr(self, 'git_revision'):
163 self.old_git_revision = self.git_revision
164 del self.git_revision
165
166 fetch(v, self)
167 parse_channel(v, self)
168 fetch_resources(v, self)
169 ensure_git_rev_available(v, self)
170 check_channel_contents(v, self)
171 return ChannelPin(
172 release_name=self.release_name,
173 tarball_url=self.table['nixexprs.tar.xz'].absolute_url,
174 tarball_sha256=self.table['nixexprs.tar.xz'].digest,
175 git_revision=self.git_revision)
176
177 # Lint TODO: Put tarball_url and tarball_sha256 in ChannelSearchPath
178 # pylint: disable=no-self-use
179 def fetch(self, v: Verification, section: str,
180 conf: configparser.SectionProxy) -> str:
181 if 'git_repo' not in conf or 'release_name' not in conf:
182 raise Exception(
183 'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
184 section)
185
186 return fetch_with_nix_prefetch_url(
187 v, conf['tarball_url'], Digest16(
188 conf['tarball_sha256']))
189
190
191 def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
192
193 def throw(error: OSError) -> None:
194 raise error
195
196 def join(x: str, y: str) -> str:
197 return y if x == '.' else os.path.join(x, y)
198
199 def recursive_files(d: str) -> Iterable[str]:
200 all_files: List[str] = []
201 for path, dirs, files in os.walk(d, onerror=throw):
202 rel = os.path.relpath(path, start=d)
203 all_files.extend(join(rel, f) for f in files)
204 for dir_or_link in dirs:
205 if os.path.islink(join(path, dir_or_link)):
206 all_files.append(join(rel, dir_or_link))
207 return all_files
208
209 def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
210 return (f for f in files if not f.startswith('.git/'))
211
212 files = functools.reduce(
213 operator.or_, (set(
214 exclude_dot_git(
215 recursive_files(x))) for x in [a, b]))
216 return filecmp.cmpfiles(a, b, files, shallow=False)
217
218
219 def fetch(v: Verification, channel: TarrableSearchPath) -> None:
220 v.status('Fetching channel')
221 request = urllib.request.urlopen(channel.channel_url, timeout=10)
222 channel.channel_html = request.read()
223 channel.forwarded_url = request.geturl()
224 v.result(request.status == 200) # type: ignore # (for old mypy)
225 v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
226
227
228 def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
229 v.status('Parsing channel description as XML')
230 d = xml.dom.minidom.parseString(channel.channel_html)
231 v.ok()
232
233 v.status('Extracting release name:')
234 title_name = d.getElementsByTagName(
235 'title')[0].firstChild.nodeValue.split()[2]
236 h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
237 v.status(title_name)
238 v.result(title_name == h1_name)
239 channel.release_name = title_name
240
241 v.status('Extracting git commit:')
242 git_commit_node = d.getElementsByTagName('tt')[0]
243 channel.git_revision = git_commit_node.firstChild.nodeValue
244 v.status(channel.git_revision)
245 v.ok()
246 v.status('Verifying git commit label')
247 v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')
248
249 v.status('Parsing table')
250 channel.table = {}
251 for row in d.getElementsByTagName('tr')[1:]:
252 name = row.childNodes[0].firstChild.firstChild.nodeValue
253 url = row.childNodes[0].firstChild.getAttribute('href')
254 size = int(row.childNodes[1].firstChild.nodeValue)
255 digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
256 channel.table[name] = ChannelTableEntry(
257 url=url, digest=digest, size=size)
258 v.ok()
259
260
261 def digest_string(s: bytes) -> Digest16:
262 return Digest16(hashlib.sha256(s).hexdigest())
263
264
265 def digest_file(filename: str) -> Digest16:
266 hasher = hashlib.sha256()
267 with open(filename, 'rb') as f:
268 # pylint: disable=cell-var-from-loop
269 for block in iter(lambda: f.read(4096), b''):
270 hasher.update(block)
271 return Digest16(hasher.hexdigest())
272
273
274 def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
275 v.status('Converting digest to base16')
276 process = subprocess.run(
277 ['nix', 'to-base16', '--type', 'sha256', digest32], stdout=subprocess.PIPE)
278 v.result(process.returncode == 0)
279 return Digest16(process.stdout.decode().strip())
280
281
282 def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
283 v.status('Converting digest to base32')
284 process = subprocess.run(
285 ['nix', 'to-base32', '--type', 'sha256', digest16], stdout=subprocess.PIPE)
286 v.result(process.returncode == 0)
287 return Digest32(process.stdout.decode().strip())
288
289
290 def fetch_with_nix_prefetch_url(
291 v: Verification,
292 url: str,
293 digest: Digest16) -> str:
294 v.status('Fetching %s' % url)
295 process = subprocess.run(
296 ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
297 v.result(process.returncode == 0)
298 prefetch_digest, path, empty = process.stdout.decode().split('\n')
299 assert empty == ''
300 v.check("Verifying nix-prefetch-url's digest",
301 to_Digest16(v, Digest32(prefetch_digest)) == digest)
302 v.status("Verifying file digest")
303 file_digest = digest_file(path)
304 v.result(file_digest == digest)
305 return path # type: ignore # (for old mypy)
306
307
308 def fetch_resources(v: Verification, channel: ChannelSearchPath) -> None:
309 for resource in ['git-revision', 'nixexprs.tar.xz']:
310 fields = channel.table[resource]
311 fields.absolute_url = urllib.parse.urljoin(
312 channel.forwarded_url, fields.url)
313 fields.file = fetch_with_nix_prefetch_url(
314 v, fields.absolute_url, fields.digest)
315 v.status('Verifying git commit on main page matches git commit in table')
316 v.result(
317 open(
318 channel.table['git-revision'].file).read(999) == channel.git_revision)
319
320
321 def git_cachedir(git_repo: str) -> str:
322 return os.path.join(
323 xdg.XDG_CACHE_HOME,
324 'pinch/git',
325 digest_string(git_repo.encode()))
326
327
328 def tarball_cache_file(channel: TarrableSearchPath) -> str:
329 return os.path.join(
330 xdg.XDG_CACHE_HOME,
331 'pinch/git-tarball',
332 '%s-%s-%s' %
333 (digest_string(channel.git_repo.encode()),
334 channel.git_revision,
335 channel.release_name))
336
337
338 def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
339 cachedir = git_cachedir(channel.git_repo)
340 v.status('Verifying rev is an ancestor of ref')
341 process = subprocess.run(['git',
342 '-C',
343 cachedir,
344 'merge-base',
345 '--is-ancestor',
346 channel.git_revision,
347 channel.git_ref])
348 v.result(process.returncode == 0)
349
350 if hasattr(channel, 'old_git_revision'):
351 v.status(
352 'Verifying rev is an ancestor of previous rev %s' %
353 channel.old_git_revision)
354 process = subprocess.run(['git',
355 '-C',
356 cachedir,
357 'merge-base',
358 '--is-ancestor',
359 channel.old_git_revision,
360 channel.git_revision])
361 v.result(process.returncode == 0)
362
363
364 def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
365 # It would be nice if we could share the nix git cache, but as of the time
366 # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
367 # yet), and trying to straddle them both is too far into nix implementation
368 # details for my comfort. So we re-implement here half of nix.fetchGit.
369 # :(
370
371 cachedir = git_cachedir(channel.git_repo)
372 if not os.path.exists(cachedir):
373 v.status("Initializing git repo")
374 process = subprocess.run(
375 ['git', 'init', '--bare', cachedir])
376 v.result(process.returncode == 0)
377
378 v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
379 # We don't use --force here because we want to abort and freak out if forced
380 # updates are happening.
381 process = subprocess.run(['git',
382 '-C',
383 cachedir,
384 'fetch',
385 channel.git_repo,
386 '%s:%s' % (channel.git_ref,
387 channel.git_ref)])
388 v.result(process.returncode == 0)
389
390 if hasattr(channel, 'git_revision'):
391 v.status('Verifying that fetch retrieved this rev')
392 process = subprocess.run(
393 ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
394 v.result(process.returncode == 0)
395 else:
396 channel.git_revision = open(
397 os.path.join(
398 cachedir,
399 'refs',
400 'heads',
401 channel.git_ref)).read(999).strip()
402
403 verify_git_ancestry(v, channel)
404
405
406 def ensure_git_rev_available(
407 v: Verification,
408 channel: TarrableSearchPath) -> None:
409 cachedir = git_cachedir(channel.git_repo)
410 if os.path.exists(cachedir):
411 v.status('Checking if we already have this rev:')
412 process = subprocess.run(
413 ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
414 if process.returncode == 0:
415 v.status('yes')
416 if process.returncode == 1:
417 v.status('no')
418 v.result(process.returncode == 0 or process.returncode == 1)
419 if process.returncode == 0:
420 verify_git_ancestry(v, channel)
421 return
422 git_fetch(v, channel)
423
424
425 def compare_tarball_and_git(
426 v: Verification,
427 channel: TarrableSearchPath,
428 channel_contents: str,
429 git_contents: str) -> None:
430 v.status('Comparing channel tarball with git checkout')
431 match, mismatch, errors = compare(os.path.join(
432 channel_contents, channel.release_name), git_contents)
433 v.ok()
434 v.check('%d files match' % len(match), len(match) > 0)
435 v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
436 expected_errors = [
437 '.git-revision',
438 '.version-suffix',
439 'nixpkgs',
440 'programs.sqlite',
441 'svn-revision']
442 benign_errors = []
443 for ee in expected_errors:
444 if ee in errors:
445 errors.remove(ee)
446 benign_errors.append(ee)
447 v.check(
448 '%d unexpected incomparable files' %
449 len(errors),
450 len(errors) == 0)
451 v.check(
452 '(%d of %d expected incomparable files)' %
453 (len(benign_errors),
454 len(expected_errors)),
455 len(benign_errors) == len(expected_errors))
456
457
458 def extract_tarball(
459 v: Verification,
460 channel: TarrableSearchPath,
461 dest: str) -> None:
462 v.status('Extracting tarball %s' %
463 channel.table['nixexprs.tar.xz'].file)
464 shutil.unpack_archive(
465 channel.table['nixexprs.tar.xz'].file,
466 dest)
467 v.ok()
468
469
470 def git_checkout(
471 v: Verification,
472 channel: TarrableSearchPath,
473 dest: str) -> None:
474 v.status('Checking out corresponding git revision')
475 git = subprocess.Popen(['git',
476 '-C',
477 git_cachedir(channel.git_repo),
478 'archive',
479 channel.git_revision],
480 stdout=subprocess.PIPE)
481 tar = subprocess.Popen(
482 ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
483 if git.stdout:
484 git.stdout.close()
485 tar.wait()
486 git.wait()
487 v.result(git.returncode == 0 and tar.returncode == 0)
488
489
490 def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
491 cache_file = tarball_cache_file(channel)
492 if os.path.exists(cache_file):
493 cached_tarball = open(cache_file).read(9999)
494 if os.path.exists(cached_tarball):
495 return cached_tarball
496
497 with tempfile.TemporaryDirectory() as output_dir:
498 output_filename = os.path.join(
499 output_dir, channel.release_name + '.tar.xz')
500 with open(output_filename, 'w') as output_file:
501 v.status(
502 'Generating tarball for git revision %s' %
503 channel.git_revision)
504 git = subprocess.Popen(['git',
505 '-C',
506 git_cachedir(channel.git_repo),
507 'archive',
508 '--prefix=%s/' % channel.release_name,
509 channel.git_revision],
510 stdout=subprocess.PIPE)
511 xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
512 xz.wait()
513 git.wait()
514 v.result(git.returncode == 0 and xz.returncode == 0)
515
516 v.status('Putting tarball in Nix store')
517 process = subprocess.run(
518 ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
519 v.result(process.returncode == 0)
520 store_tarball = process.stdout.decode().strip()
521
522 os.makedirs(os.path.dirname(cache_file), exist_ok=True)
523 open(cache_file, 'w').write(store_tarball)
524 return store_tarball # type: ignore # (for old mypy)
525
526
527 def check_channel_metadata(
528 v: Verification,
529 channel: TarrableSearchPath,
530 channel_contents: str) -> None:
531 v.status('Verifying git commit in channel tarball')
532 v.result(
533 open(
534 os.path.join(
535 channel_contents,
536 channel.release_name,
537 '.git-revision')).read(999) == channel.git_revision)
538
539 v.status(
540 'Verifying version-suffix is a suffix of release name %s:' %
541 channel.release_name)
542 version_suffix = open(
543 os.path.join(
544 channel_contents,
545 channel.release_name,
546 '.version-suffix')).read(999)
547 v.status(version_suffix)
548 v.result(channel.release_name.endswith(version_suffix))
549
550
551 def check_channel_contents(
552 v: Verification,
553 channel: TarrableSearchPath) -> None:
554 with tempfile.TemporaryDirectory() as channel_contents, \
555 tempfile.TemporaryDirectory() as git_contents:
556
557 extract_tarball(v, channel, channel_contents)
558 check_channel_metadata(v, channel, channel_contents)
559
560 git_checkout(v, channel, git_contents)
561
562 compare_tarball_and_git(v, channel, channel_contents, git_contents)
563
564 v.status('Removing temporary directories')
565 v.ok()
566
567
568 def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
569 v.status('Getting commit date')
570 process = subprocess.run(['git',
571 '-C',
572 git_cachedir(channel.git_repo),
573 'log',
574 '-n1',
575 '--format=%ct-%h',
576 '--abbrev=11',
577 '--no-show-signature',
578 channel.git_revision],
579 stdout=subprocess.PIPE)
580 v.result(process.returncode == 0 and process.stdout != b'')
581 return '%s-%s' % (os.path.basename(channel.git_repo),
582 process.stdout.decode().strip())
583
584
585 def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
586 mapping: Mapping[str, Type[SearchPath]] = {
587 'alias': AliasSearchPath,
588 'channel': ChannelSearchPath,
589 'git': GitSearchPath,
590 }
591 return mapping[conf['type']](**dict(conf.items()))
592
593
594 def read_config(filename: str) -> configparser.ConfigParser:
595 config = configparser.ConfigParser()
596 config.read_file(open(filename), filename)
597 return config
598
599
600 def read_config_files(
601 filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
602 merged_config: Dict[str, configparser.SectionProxy] = {}
603 for file in filenames:
604 config = read_config(file)
605 for section in config.sections():
606 if section in merged_config:
607 raise Exception('Duplicate channel "%s"' % section)
608 merged_config[section] = config[section]
609 return merged_config
610
611
612 def pinCommand(args: argparse.Namespace) -> None:
613 v = Verification()
614 config = read_config(args.channels_file)
615 for section in config.sections():
616 if args.channels and section not in args.channels:
617 continue
618
619 sp = read_search_path(config[section])
620
621 config[section].update(sp.pin(v)._asdict())
622
623 with open(args.channels_file, 'w') as configfile:
624 config.write(configfile)
625
626
627 def updateCommand(args: argparse.Namespace) -> None:
628 v = Verification()
629 exprs: Dict[str, str] = {}
630 config = read_config_files(args.channels_file)
631 for section in config:
632 sp = read_search_path(config[section])
633 if isinstance(sp, AliasSearchPath):
634 assert 'git_repo' not in config[section]
635 continue
636 tarball = sp.fetch(v, section, config[section])
637 exprs[section] = (
638 'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
639 (config[section]['release_name'], tarball))
640
641 for section in config:
642 if 'alias_of' in config[section]:
643 exprs[section] = exprs[str(config[section]['alias_of'])]
644
645 command = [
646 'nix-env',
647 '--profile',
648 '/nix/var/nix/profiles/per-user/%s/channels' %
649 getpass.getuser(),
650 '--show-trace',
651 '--file',
652 '<nix/unpack-channel.nix>',
653 '--install',
654 '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
655 if args.dry_run:
656 print(' '.join(map(shlex.quote, command)))
657 else:
658 v.status('Installing channels with nix-env')
659 process = subprocess.run(command)
660 v.result(process.returncode == 0)
661
662
663 def main() -> None:
664 parser = argparse.ArgumentParser(prog='pinch')
665 subparsers = parser.add_subparsers(dest='mode', required=True)
666 parser_pin = subparsers.add_parser('pin')
667 parser_pin.add_argument('channels_file', type=str)
668 parser_pin.add_argument('channels', type=str, nargs='*')
669 parser_pin.set_defaults(func=pinCommand)
670 parser_update = subparsers.add_parser('update')
671 parser_update.add_argument('--dry-run', action='store_true')
672 parser_update.add_argument('channels_file', type=str, nargs='+')
673 parser_update.set_defaults(func=updateCommand)
674 args = parser.parse_args()
675 args.func(args)
676
677
678 if __name__ == '__main__':
679 main()