]> git.scottworley.com Git - pinch/blob - pinch.py
Rename Channel -> TarrableSearchPath
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 NewType,
26 Tuple,
27 )
28
# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    """Minimal stand-in for the XDG base-directory settings pinch uses."""
    XDG_CACHE_HOME: str


# Used only when the XDG_CACHE_HOME environment variable is unset.
_DEFAULT_CACHE_HOME = os.path.expanduser('~/.cache')
xdg = XDG(XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME', _DEFAULT_CACHE_HOME))
40
41
class VerificationError(Exception):
    """Raised by Verification.result when a checked step fails."""
44
45
class Verification:
    """Prints a running log of verification steps to stderr.

    Each step is announced with status() and concluded with result(), which
    right-aligns a colored OK/FAIL marker on the same line and raises
    VerificationError when the step failed.
    """

    def __init__(self) -> None:
        # Number of characters written on the current stderr line so far.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # Counts code points, not display width; wide characters may misalign.
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL; raise VerificationError on FAIL."""
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        cols = shutil.get_terminal_size().columns or 80
        # Right-align the marker; the modulo keeps the pad non-negative when
        # the line has already run past the terminal width.
        pad = (cols - (self.line_length + len(message))) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately record its outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Record success for the step currently being reported."""
        self.result(True)
76
# sha256 digests in hex (base16) and in Nix's base32 encoding; distinct
# NewTypes so the type checker catches accidental mix-ups.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table, plus what fetching it produced."""
    absolute_url: str  # `url` resolved against the channel's forwarded URL
    digest: Digest16   # expected sha256 listed in the table
    file: str          # local path of the fetched file
    size: int          # size column of the table
    url: str           # href exactly as written in the table
87
88
class SearchPath(types.SimpleNamespace, ABC):
    """Base class for one configured channel (one config-file section).

    Instances are built with the section's key/value pairs as attributes.
    """
    release_name: str

    @abstractmethod
    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin this search path, updating the config section *conf* as needed."""
95
96
class AliasSearchPath(SearchPath):
    """A channel that reuses the source pinned by another section (``alias_of``)."""
    alias_of: str

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Nothing to pin for an alias; only validate its configuration.

        Raises:
            Exception: if the alias section also specifies ``git_repo``.
        """
        # An explicit raise rather than `assert`, so this configuration check
        # still happens when Python runs with -O.
        if hasattr(self, 'git_repo'):
            raise Exception(
                'Alias of "%s" may not also specify git_repo' % self.alias_of)
102
103
class TarrableSearchPath(SearchPath):
    """A channel whose source can be materialized as a tarball.

    It is pinned either from a Nix channel web page (``channel_url`` in the
    config) or directly from a git repository (``git_repo`` + ``git_ref``).
    """
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str
    table: Dict[str, ChannelTableEntry]

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Resolve the newest revision and write the resulting pin into *conf*."""
        # Keep the previously pinned revision so the new one can be checked
        # against it, then drop it so the revision is resolved afresh.
        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        if 'channel_url' not in conf:
            git_fetch(v, self)
            conf['release_name'] = git_revision_name(v, self)
        else:
            pin_channel(v, self)
            conf['release_name'] = self.release_name
            tarball = self.table['nixexprs.tar.xz']
            conf['tarball_url'] = tarball.absolute_url
            conf['tarball_sha256'] = tarball.digest
        conf['git_revision'] = self.git_revision

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Return the path of this channel's pinned source tarball.

        Raises:
            Exception: if *section* has not been pinned yet.
        """
        is_pinned = 'git_repo' in conf and 'release_name' in conf
        if not is_pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        if 'channel_url' not in conf:
            ensure_git_rev_available(v, self)
            return git_get_tarball(v, self)
        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(conf['tarball_sha256']))
143
144
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file, ignoring anything under ``.git/``.

    Returns ``filecmp.cmpfiles``'s ``(match, mismatch, errors)`` triple over
    the union of relative paths found under *a* and *b* (deep comparison).
    Symlinks to directories are listed as entries rather than descended into.
    """

    def raise_error(error: OSError) -> None:
        raise error

    def relative_join(base: str, name: str) -> str:
        # Avoid a spurious './' prefix for entries at the top of the tree.
        if base == '.':
            return name
        return os.path.join(base, name)

    def tree_files(root: str) -> List[str]:
        collected: List[str] = []
        for here, subdirs, plain_files in os.walk(root, onerror=raise_error):
            prefix = os.path.relpath(here, start=root)
            collected.extend(relative_join(prefix, f) for f in plain_files)
            # os.walk does not follow symlinked directories, so record the
            # links themselves as entries.
            collected.extend(
                relative_join(prefix, d) for d in subdirs
                if os.path.islink(relative_join(here, d)))
        return collected

    candidates = set(tree_files(a)) | set(tree_files(b))
    files = {f for f in candidates if not f.startswith('.git/')}
    return filecmp.cmpfiles(a, b, files, shallow=False)
171
172
def fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Download the channel page, recording its HTML and post-redirect URL.

    Sets ``channel.channel_html`` and ``channel.forwarded_url``. Fails
    verification unless the response status is 200 and the channel URL
    was forwarded to a different URL.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so its connection is closed
    # deterministically instead of whenever it is garbage-collected.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
180
181
def parse_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Extract release name, git commit and file table from the channel HTML.

    Sets ``channel.release_name``, ``channel.git_revision`` and
    ``channel.table``. Fails verification if the <title> and <h1> names
    disagree or the commit is not preceded by the expected label.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_text(tag: str) -> str:
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue

    v.status('Extracting release name:')
    # The release name is the third whitespace-separated word of each heading.
    title_name = first_text('title').split()[2]
    h1_name = first_text('h1').split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    channel.table = table
    # Skip the header row; cells are: linked file name, size, digest.
    for row in doc.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
213
214
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded sha256 digest of the bytes *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
217
218
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded sha256 digest of the file at *filename*.

    The file is read in 4096-byte chunks so it is never held in memory whole.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if not block:
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
226
227
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 with ``nix to-base16``.

    Fails verification if the nix command exits non-zero.
    """
    argv = ['nix', 'to-base16', '--type', 'sha256', digest32]
    v.status('Converting digest to base16')
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
234
235
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 with ``nix to-base32``.

    Fails verification if the nix command exits non-zero.
    """
    argv = ['nix', 'to-base32', '--type', 'sha256', digest16]
    v.status('Converting digest to base32')
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
242
243
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch *url* with nix-prefetch-url and verify it against *digest*.

    Returns the path printed by ``nix-prefetch-url --print-path``. The digest
    is checked twice: as reported by nix-prefetch-url, and by hashing the
    fetched file directly.
    """
    v.status('Fetching %s' % url)
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # Output is exactly two newline-terminated lines: base32 digest, path.
    reported, store_path, trailing = prefetch.stdout.decode().split('\n')
    assert trailing == ''
    reported16 = to_Digest16(v, Digest32(reported))
    v.check("Verifying nix-prefetch-url's digest", reported16 == digest)
    v.status("Verifying file digest")
    v.result(digest_file(store_path) == digest)
    return store_path  # type: ignore # (for old mypy)
260
261
def fetch_resources(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch the channel's git-revision and nixexprs.tar.xz files.

    Each table entry's URL is resolved against the forwarded channel URL,
    then fetched and digest-checked; afterwards the downloaded git-revision
    file must agree with the commit shown on the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Context manager so the file handle is closed promptly rather than leaked.
    with open(channel.table['git-revision'].file) as f:
        table_revision = f.read(999)
    v.result(table_revision == channel.git_revision)
273
274
def git_cachedir(git_repo: str) -> str:
    """Return the local cache directory used for the repository *git_repo*.

    Repositories are keyed by the sha256 of their location string, so any
    URL or path maps to a fixed-length, filesystem-safe directory name.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
280
281
def tarball_cache_file(channel: TarrableSearchPath) -> str:
    """Return the path of the file remembering *channel*'s generated tarball.

    The file name combines a hash of the repository, the git revision and
    the release name, so a change to any of them selects a different entry.
    """
    entry_name = '-'.join([digest_string(channel.git_repo.encode()),
                           channel.git_revision,
                           channel.release_name])
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
290
291
def verify_git_ancestry(v: Verification, channel: TarrableSearchPath) -> None:
    """Check that the pinned revision lies on the tracked ref's history.

    When a previously pinned revision is known (``old_git_revision``), also
    check that it is an ancestor of the new revision, i.e. the pin only
    moves forward.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(older: str, newer: str) -> bool:
        process = subprocess.run(
            ['git', '-C', cachedir, 'merge-base', '--is-ancestor', older, newer])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The previous rev must be an ancestor of the new one; the message
        # now states that relationship the right way round.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
316
317
def git_fetch(v: Verification, channel: TarrableSearchPath) -> None:
    """Fetch ``channel.git_ref`` into the local bare-repository cache.

    Creates the cache repository on first use. If ``channel.git_revision`` is
    already set, verifies that the fetch made it available; otherwise sets
    it from the fetched ref. Finally runs verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        # Context manager so the ref file handle is closed promptly.
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
358
359
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Make sure ``channel.git_revision`` is present in the local git cache.

    If the cache already has the revision, only its ancestry is verified;
    if the cache is missing or lacks the revision, git_fetch is run.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    process = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
    code = process.returncode
    # Exit status 0 is treated as "present", 1 as "absent", anything else as
    # a failure (v.result raises).
    if code == 0:
        v.status('yes')
    elif code == 1:
        v.status('no')
    v.result(code in (0, 1))
    if code == 0:
        verify_git_ancestry(v, channel)
    else:
        git_fetch(v, channel)
377
378
def compare_tarball_and_git(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify the extracted channel tarball agrees with the git checkout.

    At least one file must match and none may differ. Files that cannot be
    compared are tolerated only if they are on a fixed list of expected
    paths, and every path on that list must actually be incomparable.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [ee for ee in expected_errors if ee in errors]
    for ee in benign_errors:
        errors.remove(ee)
    v.check('%d unexpected incomparable files' % len(errors), not errors)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
410
411
def extract_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Unpack the fetched nixexprs.tar.xz into the directory *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
422
423
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        dest: str) -> None:
    """Extract the tree of ``channel.git_revision`` into *dest*.

    Pipes ``git archive`` from the cache repository into ``tar x``.
    """
    v.status('Checking out corresponding git revision')
    archive_cmd = ['git', '-C', git_cachedir(channel.git_repo),
                   'archive', channel.git_revision]
    untar_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    archiver = subprocess.Popen(archive_cmd, stdout=subprocess.PIPE)
    extractor = subprocess.Popen(untar_cmd, stdin=archiver.stdout)
    # Drop our copy of the pipe so git sees SIGPIPE if tar exits early.
    if archiver.stdout:
        archiver.stdout.close()
    extractor.wait()
    archiver.wait()
    v.result(archiver.returncode == 0 and extractor.returncode == 0)
442
443
def git_get_tarball(v: Verification, channel: TarrableSearchPath) -> str:
    """Return a Nix store path of an xz tarball of ``channel.git_revision``.

    The store path is remembered in tarball_cache_file(channel); a cached
    entry is reused only if the path it names still exists. Otherwise the
    tarball is generated with ``git archive | xz``, added with
    ``nix-store --add``, and recorded in the cache file.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # Binary mode: this handle only serves as xz's stdout.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our end of the pipe so that, if xz exits early, git gets
            # SIGPIPE instead of blocking forever on a full pipe.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
479
480
def check_channel_metadata(
        v: Verification,
        channel: TarrableSearchPath,
        channel_contents: str) -> None:
    """Verify the metadata files inside the extracted channel tarball.

    ``.git-revision`` must equal ``channel.git_revision`` and
    ``.version-suffix`` must be a suffix of ``channel.release_name``.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    def read_metadata(name: str) -> str:
        # Context manager so each metadata file is closed promptly.
        with open(os.path.join(release_dir, name)) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
503
504
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath) -> None:
    """Cross-check the channel tarball against the corresponding git tree.

    Both are unpacked into temporary directories, which are removed before
    the final OK is reported.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            v.status('Removing temporary directories')
    v.ok()
520
521
def pin_channel(v: Verification, channel: TarrableSearchPath) -> None:
    """Pin *channel* from its channel web page, verifying it end to end.

    In order: download the page, parse it, fetch the listed resources, make
    the git revision available locally, and compare tarball with git.
    """
    steps = (fetch, parse_channel, fetch_resources,
             ensure_git_rev_available, check_channel_contents)
    for step in steps:
        step(v, channel)
528
529
def git_revision_name(v: Verification, channel: TarrableSearchPath) -> str:
    """Build a release name for a git-pinned channel.

    The result is ``<repo basename>-<committer unix time>-<abbreviated hash>``,
    with the hash abbreviated to at least 11 characters.
    """
    v.status('Getting commit date')
    log_cmd = ['git', '-C', git_cachedir(channel.git_repo), 'log', '-n1',
               '--format=%ct-%h', '--abbrev=11', '--no-show-signature',
               channel.git_revision]
    process = subprocess.run(log_cmd, stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    stamp_and_hash = process.stdout.decode().strip()
    return '%s-%s' % (os.path.basename(channel.git_repo), stamp_and_hash)
545
546
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build the matching SearchPath subclass from a config section.

    Sections with ``alias_of`` become AliasSearchPath, all others
    TarrableSearchPath; each key in the section becomes an attribute.
    """
    kind = AliasSearchPath if 'alias_of' in conf else TarrableSearchPath
    return kind(**dict(conf.items()))
551
552
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style config file *filename*.

    The file is opened in a context manager so it is closed right after
    parsing instead of leaking until garbage collection.
    """
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config
557
558
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several config files into one section-name -> section mapping.

    Raises:
        Exception: if a section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for config in map(read_config, filenames):
        for section in config.sections():
            if section in merged_config:
                raise Exception('Duplicate channel "%s"' % section)
            merged_config[section] = config[section]
    return merged_config
569
570
def pin(args: argparse.Namespace) -> None:
    """Handle ``pinch pin``: re-pin channels and rewrite the config file.

    Only sections named in ``args.channels`` are pinned (all of them when
    that list is empty); the whole config is then written back to
    ``args.channels_file``.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = [
        section for section in config.sections()
        if not args.channels or section in args.channels]
    for section in selected:
        read_search_path(config[section]).pin(v, config[section])

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
584
585
def update(args: argparse.Namespace) -> None:
    """Handle ``pinch update``: install all pinned channels with nix-env.

    Fetches (or reuses) each pinned source tarball, builds one nix-env
    expression per channel (aliases reuse their target's source under their
    own channel name), and installs them into the user's channels profile.
    With ``--dry-run`` the nix-env command is printed instead of run.

    Raises:
        Exception: if a channel is unpinned or an alias also sets git_repo.
    """
    v = Verification()
    # section -> (release name, tarball path)
    sources: Dict[str, Tuple[str, str]] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp = read_search_path(config[section])
        if isinstance(sp, AliasSearchPath):
            # Explicit raise rather than `assert`, so it survives python -O.
            if 'git_repo' in config[section]:
                raise Exception(
                    'Alias channel "%s" may not also specify git_repo' %
                    section)
            continue
        tarball = sp.fetch(v, section, config[section])
        sources[section] = (config[section]['release_name'], tarball)

    for section in config:
        if 'alias_of' in config[section]:
            sources[section] = sources[str(config[section]['alias_of'])]

    def channel_expr(channel_name: str) -> str:
        # Fill all placeholders in a single formatting pass: the old two-pass
        # '%%s' scheme re-interpreted any '%' inside a release name or path
        # (e.g. a URL-encoded repo basename) as a format directive.
        release_name, tarball = sources[channel_name]
        return ('f: f { name = "%s"; channelName = "%s"; '
                'src = builtins.storePath "%s"; }' %
                (release_name, channel_name, tarball))

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [channel_expr(name) for name in sorted(sources)]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
620
621
def main() -> None:
    """Command-line entry point: dispatch to the ``pin`` or ``update`` mode."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == '__main__':
    main()