]> git.scottworley.com Git - pinch/blob - pinch.py
b76c8da36db18e4b608fad64c5e4bbd422b8c948
[pinch] / pinch.py
from abc import ABC, abstractmethod
import argparse
import configparser
import filecmp
import functools
import getpass
import hashlib
import operator
import os
import os.path
import shlex
import shutil
import subprocess
import sys
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom

from typing import (
    Dict,
    Iterable,
    List,
    NewType,
    Tuple,
)
28
# Use the xdg module once it is less painful to have as a dependency.


class XDG(types.SimpleNamespace):
    """The subset of XDG Base Directory locations that pinch uses."""
    XDG_CACHE_HOME: str


def xdg_cache_home() -> str:
    """Return the user's cache directory per the XDG Base Directory spec.

    An unset or empty ``$XDG_CACHE_HOME`` falls back to ``~/.cache``. The spec
    also requires relative values to be treated as invalid and ignored.
    ``os.getenv(name, default)`` on its own would return '' for an empty
    variable and silently put the cache under the current directory.
    """
    value = os.getenv('XDG_CACHE_HOME')
    if value and os.path.isabs(value):
        return value
    return os.path.expanduser('~/.cache')


xdg = XDG(XDG_CACHE_HOME=xdg_cache_home())
40
41
class VerificationError(Exception):
    """Raised by Verification.result when a check does not pass."""
44
45
class Verification:
    """Report the progress of a sequence of checks on stderr.

    ``status`` appends text to the current stderr line. ``result`` finishes
    that line with a right-aligned, ANSI-coloured ``OK`` or ``FAIL`` and
    raises VerificationError when the check failed.
    """

    def __init__(self) -> None:
        # Characters printed so far on the current stderr line; used to
        # right-align the OK/FAIL marker.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space on the current stderr line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, which may differ from terminal columns
        # for wide or combining characters.
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for colour code *c*."""
        return '\033[{:2d}m{}\033[00m'.format(c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL, depending on *r*.

        Raises:
            VerificationError: if *r* is false.
        """
        outcomes = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = outcomes[r]
        columns = shutil.get_terminal_size().columns or 80
        # The modulo keeps the padding within one terminal width when the
        # status text has already wrapped.
        padding = (columns - self.line_length - len(message)) % columns
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print *s* as a status, then immediately record result *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Record a successful result for the current line."""
        self.result(True)
75
76
# Hex (base-16) digest string, as produced by hashlib's hexdigest().
Digest16 = NewType('Digest16', str)
# Digest string in the base-32 form printed by nix-prefetch-url and
# produced by `nix to-base32`.
Digest32 = NewType('Digest32', str)
79
80
class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table, plus download results.

    parse_channel fills ``url``, ``digest`` and ``size``; fetch_resources
    later adds ``absolute_url`` and ``file``.
    """
    # ``url`` resolved against the channel's forwarded URL.
    absolute_url: str
    # Hex sha256 digest listed for the file in the table.
    digest: Digest16
    # Path of the downloaded file, as printed by nix-prefetch-url.
    file: str
    # Size as listed in the table.
    size: int
    # Link target exactly as written in the table (possibly relative).
    url: str
87
88
class SearchPath(types.SimpleNamespace, ABC):
    """Base class for one configured channel.

    Instances are built from a config section, whose options become
    attributes (see read_search_path).
    """
    release_name: str

    @abstractmethod
    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin this search path, recording the result in *conf*."""
95
96
class AliasSearchPath(SearchPath):
    """A channel that reuses another channel's pinned tarball under its own
    name (see ``update``)."""
    alias_of: str

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """There is nothing to pin for an alias; only validate the config.

        Raises:
            Exception: if the alias section also sets ``git_repo``. This is an
                explicit raise rather than an ``assert`` so that it still runs
                under ``python -O``.
        """
        if hasattr(self, 'git_repo'):
            raise Exception(
                'Alias of "%s" must not also set git_repo' % self.alias_of)
102
103
class TarrableSearchPath(SearchPath):
    """A search path whose contents can be materialised as a tarball.

    It is pinned from a channel page when the section has ``channel_url``,
    and directly from ``git_repo``/``git_ref`` otherwise.
    """
    # Raw bytes of the fetched channel page.
    channel_html: bytes
    channel_url: str
    # Final URL of the channel page after redirects.
    forwarded_url: str
    git_ref: str
    git_repo: str
    # Commit being pinned.
    git_revision: str
    # Commit recorded by the previous pin, if any.
    old_git_revision: str
    # Channel page file table: file name -> entry.
    table: Dict[str, ChannelTableEntry]

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Resolve the current revision and write the pin into *conf*."""
        # Keep the previously pinned revision so ancestry can be checked
        # against it, then drop it so a fresh revision gets resolved.
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        if 'channel_url' not in conf:
            git_fetch(v, self)
            conf['release_name'] = git_revision_name(v, self)
        else:
            pin_channel(v, self)
            tarball_entry = self.table['nixexprs.tar.xz']
            conf['release_name'] = self.release_name
            conf['tarball_url'] = tarball_entry.absolute_url
            conf['tarball_sha256'] = tarball_entry.digest
        conf['git_revision'] = self.git_revision

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the store path of this section's pinned tarball.

        Raises:
            Exception: if the section has not been pinned yet.
        """
        pinned = 'git_repo' in conf and 'release_name' in conf
        if not pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        if 'channel_url' not in conf:
            ensure_git_rev_available(v, self)
            return git_get_tarball(v, self)
        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(conf['tarball_sha256']))
143
144
class GitSearchPath(TarrableSearchPath):
    """Search path pinned straight from a git ref (section has no
    ``channel_url`` and no ``alias_of``)."""
147
148
class ChannelSearchPath(TarrableSearchPath):
    """Search path pinned from a channel page (section has ``channel_url``)."""
151
152
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at *a* and *b* by content.

    The paths compared are every regular file and every symlink to a
    directory found under either root, excluding anything under a top-level
    ``.git/`` directory.

    Returns:
        The ``(match, mismatch, errors)`` lists from filecmp.cmpfiles, holding
        paths relative to the roots. ``errors`` includes paths that exist
        under only one of the roots.
    """

    def reraise(error: OSError) -> None:
        # Make os.walk fail loudly instead of silently skipping directories.
        raise error

    def under(parent: str, name: str) -> str:
        # relpath(root, root) is '.', which should not prefix entries.
        return name if parent == '.' else os.path.join(parent, name)

    def tree(root: str) -> List[str]:
        entries: List[str] = []
        for dirpath, dirnames, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(dirpath, start=root)
            for filename in filenames:
                entries.append(under(relative, filename))
            # os.walk does not descend into symlinked directories by default,
            # so record the links themselves as entries.
            for dirname in dirnames:
                if os.path.islink(under(dirpath, dirname)):
                    entries.append(under(relative, dirname))
        return entries

    def tracked(root: str) -> set:
        return {path for path in tree(root) if not path.startswith('.git/')}

    return filecmp.cmpfiles(a, b, tracked(a) | tracked(b), shallow=False)
179
180
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download ``channel.channel_url`` and record where it redirected to.

    Sets ``channel.channel_html`` (raw response body) and
    ``channel.forwarded_url`` (final URL after redirects). Verification fails
    if the HTTP status is not 200 or if the URL was not redirected.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed even
    # if reading fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
188
189
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Parse ``channel.channel_html`` as XML and fill in the channel fields.

    Sets ``channel.release_name``, ``channel.git_revision`` and
    ``channel.table``. The page is expected to have these properties:
    - the release name is the third word of both the first <title> and the
      first <h1>;
    - the commit is the text of the first <tt>, preceded by the text
      'Git commit ';
    - every <tr> after the first holds the cells (link, size, digest).
    """
    v.status('Parsing channel description as XML')
    dom = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word_of_first(tag: str) -> str:
        text = dom.getElementsByTagName(tag)[0].firstChild.nodeValue
        return text.split()[2]

    v.status('Extracting release name:')
    title_release = third_word_of_first('title')
    heading_release = third_word_of_first('h1')
    v.status(title_release)
    v.result(title_release == heading_release)
    channel.release_name = title_release

    v.status('Extracting git commit:')
    commit_node = dom.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    all_rows = dom.getElementsByTagName('tr')
    for row in all_rows[1:]:  # the first row is the header
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
221
222
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 digest of *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
225
226
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded SHA-256 digest of the file *filename*.

    The file is read in 4096-byte chunks, so large files are never held in
    memory all at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        read_chunk = functools.partial(stream.read, 4096)
        for chunk in iter(read_chunk, b''):
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
234
235
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 SHA-256 digest to hex using ``nix to-base16``."""
    v.status('Converting digest to base16')
    argv = ['nix', 'to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
242
243
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base-32 using ``nix to-base32``."""
    v.status('Converting digest to base32')
    argv = ['nix', 'to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
250
251
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with nix-prefetch-url and return the resulting path.

    Two checks are made against *digest*: the digest that nix-prefetch-url
    reports (after conversion to hex), and a SHA-256 computed locally over
    the downloaded file.
    """
    v.status('Fetching %s' % url)
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # Output is exactly two lines: the base-32 digest, then the path.
    output_lines = prefetch.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''
    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
268
269
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel's ``git-revision`` and ``nixexprs.tar.xz`` files.

    Each table entry's URL is resolved against the forwarded channel URL and
    stored in ``absolute_url``. The downloaded path is stored in ``file``.
    Finally, the downloaded git-revision file must match the commit shown on
    the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaving it to garbage collection.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
281
282
def git_cachedir(git_repo: str) -> str:
    """Return the local cache directory used for a clone of *git_repo*.

    The directory is named by the SHA-256 hex digest of the repository URL,
    so any URL maps to a fixed-length, filesystem-safe name.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
288
289
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the cache-file path for the tarball of *channel*'s revision.

    The file name combines the hashed repository URL, the git revision and
    the release name.
    """
    file_name = '{}-{}-{}'.format(
        digest_string(channel.git_repo.encode()),
        channel.git_revision,
        channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', file_name)
298
299
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Check where ``channel.git_revision`` sits in the cached repository.

    Two conditions must hold:
    - the revision is an ancestor of (or equal to) ``channel.git_ref``;
    - if a previously pinned revision is known, the new revision descends
      from it.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 iff A is an ancestor of
        # B (a commit counts as its own ancestor).
        process = subprocess.run(['git', '-C', cachedir, 'merge-base',
                                  '--is-ancestor', ancestor, descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The previous pin must be an ancestor of the new revision, i.e. the
        # new revision is a descendant of it. The message reflects that.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
324
325
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch ``channel.git_ref`` from ``channel.git_repo`` into a bare cache.

    If ``channel.git_revision`` is already set, verify that the fetch
    retrieved it. Otherwise, set it to the commit the ref now points at.
    Ancestry is verified afterwards in both cases.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Ask git to resolve the ref rather than reading refs/heads/<ref>
        # directly. After `git gc` or `git pack-refs`, the ref may exist only
        # in packed-refs, and a fetch that leaves the ref unchanged does not
        # recreate the loose file.
        v.status('Resolving ref "%s"' % channel.git_ref)
        process = subprocess.run(
            ['git', '-C', cachedir, 'rev-parse', '--verify',
             'refs/heads/%s' % channel.git_ref],
            stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
366
367
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure ``channel.git_revision`` exists in the local git cache.

    If the cache already has the revision, only ancestry is verified.
    Otherwise, the ref is fetched with git_fetch.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        returncode = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e',
             channel.git_revision]).returncode
        # `git cat-file -e` exits 0 if the object exists and 1 if it does
        # not; any other code is treated as a failure.
        answer = {0: 'yes', 1: 'no'}.get(returncode)
        if answer is not None:
            v.status(answer)
        v.result(answer is not None)
        if returncode == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
385
386
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    The tarball's files live under ``channel_contents/<release_name>``. These
    conditions must all hold:
    - at least one file matches;
    - no file differs;
    - the set of files that could not be compared is exactly the hard-coded
      list of expected ones.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), bool(match))
    v.check('%d files differ' % len(mismatch), not mismatch)

    # Paths that are expected to be incomparable between tarball and checkout.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for expected in expected_errors:
        if expected not in errors:
            continue
        errors.remove(expected)
        benign_errors.append(expected)

    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
418
419
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the channel's downloaded nixexprs tarball into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
430
431
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Extract ``channel.git_revision`` from the git cache into *dest*.

    Runs ``git archive | tar x`` and requires both processes to succeed.
    """
    v.status('Checking out corresponding git revision')
    archive = subprocess.Popen(
        ['git', '-C', git_cachedir(channel.git_repo), 'archive',
         channel.git_revision],
        stdout=subprocess.PIPE)
    extract = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archive.stdout)
    # Drop the parent's copy of the pipe so that only tar holds the read end.
    if archive.stdout:
        archive.stdout.close()
    extract.wait()
    archive.wait()
    v.result(archive.returncode == 0 and extract.returncode == 0)
450
451
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Return a Nix store path of an xz tarball of ``channel.git_revision``.

    The tarball's contents are placed under a ``<release_name>/`` prefix. The
    resulting store path is recorded in tarball_cache_file(channel) and
    reused while that path still exists.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as cache:
            cached_tarball = cache.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes binary data straight to this file's descriptor.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close the parent's copy of the pipe (as git_checkout does) so
            # git gets SIGPIPE if xz exits early and the fd is not leaked.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball
487
488
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Check the metadata files inside the extracted channel tarball.

    Two conditions must hold:
    - ``.git-revision`` equals ``channel.git_revision``;
    - ``.version-suffix`` is a suffix of ``channel.release_name``.
    """

    def read_metadata(name: str) -> str:
        # The tarball unpacks into a top-level <release_name>/ directory.
        path = os.path.join(channel_contents, channel.release_name, name)
        # Use `with` so the handle is closed promptly rather than leaked.
        with open(path) as metadata_file:
            return metadata_file.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
511
512
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Verify the channel tarball against a checkout of the same revision.

    Both are unpacked into temporary directories, which are removed before
    the final OK is reported.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            v.status('Removing temporary directories')
    v.ok()
528
529
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Run every step that pins and verifies a channel, in order."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
536
537
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name for a git-pinned channel.

    The name has the form ``<repo basename>-<committer timestamp>-<hash>``.
    The timestamp is git's ``%ct`` (seconds since the epoch). The hash is
    abbreviated to at least 11 characters.
    """
    v.status('Getting commit date')
    log_command = [
        'git', '-C', git_cachedir(channel.git_repo),
        'log', '-n1', '--format=%ct-%h', '--abbrev=11', '--no-show-signature',
        channel.git_revision,
    ]
    completed = subprocess.run(log_command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0 and completed.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    return '{}-{}'.format(repo_name, completed.stdout.decode().strip())
553
554
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the SearchPath subclass that matches a config section.

    A section with ``alias_of`` becomes an alias, one with ``channel_url``
    becomes a channel, and anything else becomes a plain git search path.
    Every option in the section becomes an attribute of the result.
    """
    options = dict(conf.items())
    if 'alias_of' in conf:
        return AliasSearchPath(**options)
    if 'channel_url' not in conf:
        return GitSearchPath(**options)
    return ChannelSearchPath(**options)
561
562
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    *filename* is also passed as the source name, so parse errors mention it.
    """
    config = configparser.ConfigParser()
    # Close the file as soon as it is parsed instead of leaking the handle.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
567
568
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files into one mapping of section name to section.

    Raises:
        Exception: if the same section name appears in more than one place.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for path in filenames:
        parsed = read_config(path)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
579
580
def pin(args: argparse.Namespace) -> None:
    """Pin channels in ``args.channels_file`` and write the file back.

    Only the sections named in ``args.channels`` are pinned. When that list
    is empty, every section is pinned.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if not wanted or section in wanted:
            section_conf = config[section]
            read_search_path(section_conf).pin(v, section_conf)

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
594
595
def update(args: argparse.Namespace) -> None:
    """Install every configured channel into the user's nix channels profile.

    Each non-alias section's pinned tarball is obtained with its ``fetch``
    method. Alias sections reuse the tarball of the channel they (possibly
    transitively) alias. Then ``nix-env --install`` is run with one
    ``<nix/unpack-channel.nix>`` expression per channel. With ``--dry-run``,
    the command is printed instead of run.

    Raises:
        Exception: if an alias section also sets ``git_repo``, names an
            unknown channel, or is part of an alias cycle.
    """
    v = Verification()
    config = read_config_files(args.channels_file)

    # Section name -> (release_name, store path of its tarball).
    sources: Dict[str, Tuple[str, str]] = {}
    for section in config:
        sp = read_search_path(config[section])
        if isinstance(sp, AliasSearchPath):
            # Explicit check rather than assert, so it also runs under -O.
            if 'git_repo' in config[section]:
                raise Exception(
                    'Alias channel "%s" must not set git_repo' % section)
            continue
        tarball = sp.fetch(v, section, config[section])
        sources[section] = (config[section]['release_name'], tarball)

    def resolve_alias(section: str) -> str:
        # Follow alias_of links until a non-alias channel is reached, so that
        # aliases of aliases work regardless of section order.
        chain = [section]
        while 'alias_of' in config[chain[-1]]:
            target = config[chain[-1]]['alias_of']
            if target not in config:
                raise Exception(
                    'Channel "%s" is an alias of unknown channel "%s"' %
                    (chain[-1], target))
            if target in chain:
                raise Exception(
                    'Alias cycle: %s' % ' -> '.join(chain + [target]))
            chain.append(target)
        return chain[-1]

    for section in config:
        if 'alias_of' in config[section]:
            sources[section] = sources[resolve_alias(section)]

    # Format each expression in one step, so that a '%' inside a release
    # name or store path cannot be misread as a format directive.
    expressions = [
        'f: f { name = "%s"; channelName = "%s"; '
        'src = builtins.storePath "%s"; }' % (release_name, name, tarball)
        for name, (release_name, tarball) in sorted(sources.items())]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + expressions
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
630
631
def main() -> None:
    """Command-line entry point: dispatch to the ``pin`` or ``update``
    subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    commands = parser.add_subparsers(dest='mode', required=True)

    pin_command = commands.add_parser('pin')
    pin_command.add_argument('channels_file', type=str)
    # Optional list of section names to pin; empty means all of them.
    pin_command.add_argument('channels', type=str, nargs='*')
    pin_command.set_defaults(func=pin)

    update_command = commands.add_parser('update')
    update_command.add_argument('--dry-run', action='store_true')
    update_command.add_argument('channels_file', type=str, nargs='+')
    update_command.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == '__main__':
    main()