pinch.py (pinch repository at git.scottworley.com, commit d346a14f733a0ee071978b28b10e158cdfcaaf45)
import argparse
import configparser
import filecmp
import functools
import getpass
import hashlib
import operator
import os
import os.path
import shlex
import shutil
import subprocess
import sys
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom

from typing import (
    Dict,
    Iterable,
    List,
    NewType,
    Tuple,
)

# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    XDG_CACHE_HOME: str


xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))


class VerificationError(Exception):
    pass


class Verification:
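    """Report progress; raise VerificationError on a failed check."""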

    def __init__(self) -> None:
        self.line_length = 0

    def status(self, s: str) -> None:
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        cols = shutil.get_terminal_size().columns or 80
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        self.result(True)


Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


class SearchPath(types.SimpleNamespace):
    release_name: str


class Channel(SearchPath):
    alias_of: str
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str
    table: Dict[str, ChannelTableEntry]

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
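        """Pin this channel and write the pin back into its config section.

        Aliases are left untouched; URL-based channels are pinned to a
        tarball URL and sha256; git-based channels are pinned to a git
        revision.
        """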
        if hasattr(self, 'alias_of'):
            assert not hasattr(self, 'git_repo')
            return

        if hasattr(self, 'git_revision'):
            self.old_git_revision = self.git_revision
            del self.git_revision

        if 'channel_url' in conf:
            pin_channel(v, self)
            conf['release_name'] = self.release_name
            conf['tarball_url'] = self.table['nixexprs.tar.xz'].absolute_url
            conf['tarball_sha256'] = self.table['nixexprs.tar.xz'].digest
        else:
            git_fetch(v, self)
            conf['release_name'] = git_revision_name(v, self)
            conf['git_revision'] = self.git_revision


def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
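    """Recursively compare directories a and b, ignoring .git/.

    Returns the same (match, mismatch, errors) triple as filecmp.cmpfiles.
    """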

    def throw(error: OSError) -> None:
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    return filecmp.cmpfiles(a, b, files, shallow=False)


def fetch(v: Verification, channel: Channel) -> None:
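    """Download the channel page and record the URL it forwards to."""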
    v.status('Fetching channel')
    request = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel.channel_html = request.read()
    channel.forwarded_url = request.geturl()
    v.result(request.status == 200)  # type: ignore  # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)


def parse_channel(v: Verification, channel: Channel) -> None:
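    """Parse release name, git commit, and file table from the HTML."""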
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()


def digest_string(s: bytes) -> Digest16:
    return Digest16(hashlib.sha256(s).hexdigest())


def digest_file(filename: str) -> Digest16:
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # pylint: disable=cell-var-from-loop
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return Digest16(hasher.hexdigest())


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    v.status('Converting digest to base16')
    process = subprocess.run(
        ['nix', 'to-base16', '--type', 'sha256', digest32], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    v.status('Converting digest to base32')
    process = subprocess.run(
        ['nix', 'to-base32', '--type', 'sha256', digest16], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())


def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
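    """Fetch url into the Nix store and double-check its sha256.

    Returns the store path reported by nix-prefetch-url.
    """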
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore  # (for old mypy)


def fetch_resources(v: Verification, channel: Channel) -> None:
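    """Fetch git-revision and nixexprs.tar.xz; cross-check the commit."""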
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    v.result(
        open(
            channel.table['git-revision'].file).read(999) == channel.git_revision)


def git_cachedir(git_repo: str) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git',
        digest_string(git_repo.encode()))


def tarball_cache_file(channel: Channel) -> str:
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git-tarball',
        '%s-%s-%s' %
        (digest_string(channel.git_repo.encode()),
         channel.git_revision,
         channel.release_name))


def verify_git_ancestry(v: Verification, channel: Channel) -> None:
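    """Check that the rev is reachable from the ref and, when re-pinning,
    that it descends from the previously pinned rev."""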
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              channel.git_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if hasattr(channel, 'old_git_revision'):
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            channel.old_git_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  channel.old_git_revision,
                                  channel.git_revision])
        v.result(process.returncode == 0)


def git_fetch(v: Verification, channel: Channel) -> None:
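    """Fetch the ref into the local cache; determine or verify the rev."""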
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        channel.git_revision = open(
            os.path.join(
                cachedir,
                'refs',
                'heads',
                channel.git_ref)).read(999).strip()

    verify_git_ancestry(v, channel)


def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
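    """Make sure the pinned rev is in the local cache, fetching if needed."""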
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        v.result(process.returncode == 0 or process.returncode == 1)
        if process.returncode == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)


def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
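    """Diff the unpacked channel tarball against the git checkout.

    A fixed set of channel-only files is expected to be incomparable.
    """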
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))


def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    v.status('Extracting tarball %s' %
             channel.table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(
        channel.table['nixexprs.tar.xz'].file,
        dest)
    v.ok()


def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)


def git_get_tarball(v: Verification, channel: Channel) -> str:
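    """Build a channel-style tarball for the pinned rev in the Nix store.

    A small cache file maps (repo, rev, release name) to the store path.
    """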
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        cached_tarball = open(cache_file).read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        open(cache_file, 'w').write(store_tarball)
        return store_tarball  # type: ignore  # (for old mypy)


def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
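    """Check .git-revision and .version-suffix in the unpacked tarball."""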
    v.status('Verifying git commit in channel tarball')
    v.result(
        open(
            os.path.join(
                channel_contents,
                channel.release_name,
                '.git-revision')).read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = open(
        os.path.join(
            channel_contents,
            channel.release_name,
            '.version-suffix')).read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))


def check_channel_contents(v: Verification, channel: Channel) -> None:
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()


def pin_channel(v: Verification, channel: Channel) -> None:
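    """Fetch, parse, and verify a URL-based channel end to end."""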
    fetch(v, channel)
    parse_channel(v, channel)
    fetch_resources(v, channel)
    ensure_git_rev_available(v, channel)
    check_channel_contents(v, channel)


def git_revision_name(v: Verification, channel: Channel) -> str:
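    """Derive a release name of the form <repo>-<commit time>-<hash>."""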
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              channel.git_revision],
                             stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())


def read_search_path(conf: configparser.SectionProxy) -> Channel:
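    """Build a Channel from one section of a channels file.

    Illustrative example only (the values here are made up; the key
    names are the ones this script reads and writes):

        [nixos]
        git_repo = https://github.com/NixOS/nixpkgs.git
        git_ref = nixos-20.03
    """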
    return Channel(**dict(conf.items()))


def read_config(filename: str) -> configparser.ConfigParser:
    config = configparser.ConfigParser()
    config.read_file(open(filename), filename)
    return config


def pin(args: argparse.Namespace) -> None:
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        channel = read_search_path(config[section])

        channel.pin(v, config[section])

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)


def fetch_channel(
        v: Verification,
        section: str,
        conf: configparser.SectionProxy) -> str:
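    """Return a Nix store path holding the pinned channel's tarball.

    URL-pinned channels are fetched; git-pinned channels are built.
    """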
    if 'git_repo' not in conf or 'release_name' not in conf:
        raise Exception(
            'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
            section)

    if 'channel_url' in conf:
        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], Digest16(
                conf['tarball_sha256']))

    channel = read_search_path(conf)
    ensure_git_rev_available(v, channel)
    return git_get_tarball(v, channel)


def update(args: argparse.Namespace) -> None:
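    """Install every pinned channel into the user's channels profile.

    With --dry-run, print the nix-env command instead of running it.
    """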
    v = Verification()
    config = configparser.ConfigParser()
    exprs: Dict[str, str] = {}
    configs = [read_config(filename) for filename in args.channels_file]
    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                assert 'git_repo' not in config[section]
                continue
            tarball = fetch_channel(v, section, config[section])
            if section in exprs:
                raise Exception('Duplicate channel "%s"' % section)
            exprs[section] = (
                'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
                (config[section]['release_name'], tarball))

    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                if section in exprs:
                    raise Exception('Duplicate channel "%s"' % section)
                exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)


def main() -> None:
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pin)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=update)
    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()