]> git.scottworley.com Git - pinch/blob - pinch.py
a8af6a67aa6f50b2f6a3dbfba53b45a166d5b71d
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
28 # Use xdg module when it's less painful to have as a dependency
29
30
class XDG(types.SimpleNamespace):
    """The subset of XDG base-directory settings that pinch needs."""
    XDG_CACHE_HOME: str


def xdg_cache_home() -> str:
    """Return the user's XDG cache directory.

    Per the XDG Base Directory Specification, $XDG_CACHE_HOME is honoured only
    when it holds an absolute path.  If it is unset, empty, or relative, the
    default ``~/.cache`` is used instead.  (Previously an empty value made
    pinch cache into a path relative to the current working directory.)
    """
    value = os.environ.get('XDG_CACHE_HOME', '')
    if os.path.isabs(value):
        return value
    return os.path.expanduser('~/.cache')


xdg = XDG(XDG_CACHE_HOME=xdg_cache_home())
39
40
class VerificationError(Exception):
    """Raised by Verification.result() when a checked step fails."""
43
44
class Verification:
    """Print a running checklist of verification steps to stderr.

    A step is announced with status() (possibly several times) and concluded
    with result(); the outcome is right-aligned on the terminal line.  A
    failing result prints FAIL and raises VerificationError.
    """

    def __init__(self) -> None:
        # Characters printed on the current status line so far.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* (followed by a space) on the current status line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # One extra for the trailing space.  len() counts code points, which
        # may not match terminal columns for wide characters.
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape for color code *c*."""
        return '\033[{:2d}m{}\033[00m'.format(c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with OK or FAIL; raise on failure."""
        outcomes = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = outcomes[r]
        width = shutil.get_terminal_size().columns or 80
        # Right-align the message, wrapping around if the line is too long.
        padding = (width - self.line_length - len(message)) % width
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce *s* and immediately record outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Record a successful outcome for the current step."""
        self.result(True)
75
# SHA-256 digest spelled as a base-16 (hex) string.
Digest16 = NewType('Digest16', str)
# SHA-256 digest in Nix's base-32 encoding (as printed by nix-prefetch-url).
Digest32 = NewType('Digest32', str)
78
79
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel page."""
    # url resolved against the channel's forwarded URL.
    absolute_url: str
    # Expected SHA-256 of the file, in hex.
    digest: Digest16
    # Local path of the downloaded file.
    file: str
    # Size as listed in the table.
    size: int
    # The href exactly as it appears in the table.
    url: str
86
87
class SearchPath(types.SimpleNamespace):
    """Base for channel descriptions; attributes come from keyword args."""
    release_name: str
90
91
class Channel(SearchPath):
    """State for one channel while it is being pinned or updated.

    Some attributes are copied from the config section.  Others are filled
    in as the pinning steps run.
    """
    # Name of another section this one aliases.
    alias_of: str
    # Body of the channel page, and the URL it redirected to.
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    # Git ref to track, repository URL, and the pinned commit.
    git_ref: str
    git_repo: str
    git_revision: str
    # The commit pinned previously, if there was one.
    old_git_revision: str
    # Files listed on the channel page, keyed by file name.
    table: Dict[str, ChannelTableEntry]
102
103
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file-by-file, by content.

    The compared set is the union of the files found under either tree.
    Files are relative paths to regular files, plus symlinks to directories.
    Paths under a top-level ``.git/`` are left out.  Errors from walking a
    tree are raised, not ignored.

    Returns filecmp.cmpfiles' (match, mismatch, errors) triple.
    """

    def on_walk_error(err: OSError) -> None:
        raise err

    def rel_join(prefix: str, name: str) -> str:
        # relpath reports the root itself as '.', so avoid './name' paths.
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def tree_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=on_walk_error):
            relative = os.path.relpath(dirpath, start=root)
            found.extend(rel_join(relative, name) for name in filenames)
            # os.walk lists directory symlinks among subdirs without
            # descending into them; compare the links themselves.
            found.extend(
                rel_join(relative, name)
                for name in subdirs
                if os.path.islink(rel_join(dirpath, name)))
        return found

    left, right = (
        {path for path in tree_files(root) if not path.startswith('.git/')}
        for root in (a, b))
    return filecmp.cmpfiles(a, b, left | right, shallow=False)
130
131
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page and record where it redirected to.

    Sets channel.channel_html (raw response body) and channel.forwarded_url
    (the final URL after redirects).  Verifies that the HTTP status is 200
    and that the channel URL was actually forwarded somewhere else.  The
    forwarded URL is what the page's relative links are resolved against.

    Raises VerificationError if either check fails.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
139
140
def parse_channel(v: Verification, channel: Channel) -> None:
    """Scrape the release name, git commit and file table from the page.

    Fills in channel.release_name, channel.git_revision and channel.table
    (file name -> ChannelTableEntry with url, digest and size).  Checks that
    <title> and <h1> agree on the release name, and that the text just
    before the first <tt> is the 'Git commit ' label.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_text(tag: str) -> str:
        # Text of the first <tag> element in the document.
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue

    v.status('Extracting release name:')
    # The release name is the third whitespace-separated word of each heading.
    from_title = first_text('title').split()[2]
    from_h1 = first_text('h1').split()[2]
    v.status(from_title)
    v.result(from_title == from_h1)
    channel.release_name = from_title

    v.status('Extracting git commit:')
    commit_tt = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_tt.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_tt.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # Skip the header row.  Cells: <a href=url>name</a>, size, <tt>digest</tt>.
    for tr in doc.getElementsByTagName('tr')[1:]:
        link = tr.childNodes[0].firstChild
        file_name = link.firstChild.nodeValue
        href = link.getAttribute('href')
        file_size = int(tr.childNodes[1].firstChild.nodeValue)
        sha = Digest16(tr.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[file_name] = ChannelTableEntry(
            url=href, digest=sha, size=file_size)
    v.ok()
172
173
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of the bytes *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
176
177
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of *filename*, read in 4096-byte chunks."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        for chunk in iter(functools.partial(stream.read, 4096), b''):
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
185
186
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 SHA-256 digest to hex using ``nix to-base16``.

    Raises VerificationError if the nix command exits non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
193
194
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base-32 using ``nix to-base32``.

    Raises VerificationError if the nix command exits non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
201
202
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* into the Nix store and check it against *digest*.

    The hash is verified three times: by nix-prefetch-url itself (it is
    given the expected hash), by comparing the hash it reports, and by
    hashing the resulting store file locally.

    Returns the store path.  Raises VerificationError on any mismatch or
    command failure.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    # With --print-path the output is "<base32 hash>\n<store path>\n".
    output_lines = completed.stdout.decode().split('\n')
    prefetch_digest, path, trailing = output_lines
    assert trailing == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    local_digest = digest_file(path)
    v.result(local_digest == digest)
    return path  # type: ignore # (for old mypy)
219
220
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the channel's git-revision file and nixexprs tarball.

    For each resource, resolves its table link against the forwarded
    channel URL and stores the result in .absolute_url.  The file is then
    fetched via nix-prefetch-url, which verifies the table's digest, and the
    local path is stored in .file.  Finally checks that the downloaded
    git-revision file matches the commit shown on the channel page.

    Raises VerificationError if any check fails.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        # Table links may be relative to the (post-redirect) channel page.
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
232
233
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory for the bare clone of *git_repo*.

    Repositories are keyed by the SHA-256 of their URL, so any URL maps to
    a fixed-length directory name.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
239
240
def tarball_cache_file(channel: Channel) -> str:
    """Return the cache file recording the store path of a generated tarball.

    The file name combines the repo URL's SHA-256, the git revision and the
    release name.
    """
    repo_key = digest_string(channel.git_repo.encode())
    entry_name = '%s-%s-%s' % (
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
249
250
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check where the pinned revision sits in the cached repo's history.

    Uses ``git merge-base --is-ancestor`` in the local cache to verify that
    channel.git_revision is reachable from channel.git_ref.  When a
    previously pinned channel.old_git_revision exists, it also verifies that
    the old revision is an ancestor of the new one, so the pin only moves
    forward.

    Raises VerificationError if either check fails.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The message now matches the check: old rev must precede new rev.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
275
276
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local bare cache.

    Initializes the cache repository on first use.  The refspec
    "<ref>:<ref>" is fetched without --force, so git rejects a
    non-fast-forward update of the ref and pinning aborts.

    If channel.git_revision is already set, verifies that the fetch brought
    in that commit.  Otherwise, records the commit the fetched ref points at
    as channel.git_revision.  Finally checks ancestry with
    verify_git_ancestry.

    Raises VerificationError if any step fails.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Ask git to resolve the ref rather than reading refs/heads/<ref>
        # directly: after packing (e.g. `git gc`) the ref lives in
        # packed-refs and no loose ref file exists.
        v.status('Resolving fetched ref')
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'rev-parse',
                                  '--verify',
                                  'refs/heads/%s' % channel.git_ref],
                                 stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
317
318
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision is present in the local git cache.

    If the cache already contains the commit, only its ancestry is
    re-verified.  Otherwise (no cache yet, or commit missing) git_fetch
    is run.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    code = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision]
    ).returncode
    # Exit code 0 means present and 1 means absent; any other code fails.
    answer = {0: 'yes', 1: 'no'}.get(code)
    if answer is not None:
        v.status(answer)
    v.result(answer is not None)
    if code == 0:
        verify_git_ancestry(v, channel)
        return
    git_fetch(v, channel)
334
335
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    Requires at least one matching file and no differing files.  Files that
    exist on only one side (or cannot be compared) must be exactly the known
    set of channel-generated files listed below.

    Raises VerificationError if any check fails.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
367
368
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the downloaded nixexprs tarball into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
376
377
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Extract channel.git_revision from the git cache into *dest*.

    Pipes ``git archive`` into ``tar x``.  Raises VerificationError if
    either process fails.
    """
    v.status('Checking out corresponding git revision')
    archive_command = ['git',
                       '-C',
                       git_cachedir(channel.git_repo),
                       'archive',
                       channel.git_revision]
    archiver = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    extractor = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archiver.stdout)
    # Drop our copy of the pipe so git sees SIGPIPE if tar exits early.
    if archiver.stdout is not None:
        archiver.stdout.close()
    extractor.wait()
    archiver.wait()
    v.result(all(p.returncode == 0 for p in (archiver, extractor)))
393
394
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Return a Nix store path holding an xz tarball of channel.git_revision.

    A previously generated store path is reused when its cache record
    exists and the path is still present.  Otherwise ``git archive`` output
    (prefixed with the release name) is compressed with xz, added with
    ``nix-store --add``, and the resulting store path is recorded in the
    cache.

    Raises VerificationError if any command fails.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as cache:
            cached_tarball = cache.read(9999)
        # The store path may have been garbage-collected since it was cached.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # Binary mode: xz writes compressed bytes straight to this file.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe so git gets SIGPIPE instead of
            # blocking forever if xz exits early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
430
431
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the extracted channel tarball.

    Checks two things:
    - ``.git-revision`` must equal channel.git_revision.
    - ``.version-suffix`` must be a suffix of channel.release_name.

    Raises VerificationError if either check fails.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    def read_release_file(name: str) -> str:
        # Read up to 999 characters, closing the file afterwards.
        with open(os.path.join(release_dir, name)) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_release_file('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_release_file('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
454
455
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack the channel tarball and a git checkout side by side and compare.

    Both go into temporary directories that are removed when done.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            v.status('Removing temporary directories')
    v.ok()
469
470
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every step needed to pin and verify a channel, in order."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
477
478
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name for a git-only channel.

    The name is "<repo basename>-<committer unix time>-<abbreviated hash>",
    and the hash is abbreviated to at least 11 characters.  Raises
    VerificationError if git fails or prints nothing.
    """
    v.status('Getting commit date')
    log_command = ['git',
                   '-C',
                   git_cachedir(channel.git_repo),
                   'log',
                   '-n1',
                   '--format=%ct-%h',
                   '--abbrev=11',
                   '--no-show-signature',
                   channel.git_revision]
    completed = subprocess.run(log_command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0 and completed.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    date_and_hash = completed.stdout.decode().strip()
    return '%s-%s' % (repo_name, date_and_hash)
494
495
def read_search_path(conf: configparser.SectionProxy) -> Channel:
    """Build a Channel whose attributes are the section's key/value pairs."""
    settings = {key: value for key, value in conf.items()}
    return Channel(**settings)
498
499
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    The file is closed after reading.  *filename* is passed as the source
    name that configparser reports in its error messages.
    """
    config = configparser.ConfigParser()
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
504
505
def pin(args: argparse.Namespace) -> None:
    """Pin the channels in args.channels_file and rewrite the file.

    Only sections named in args.channels are pinned; an empty list means
    all sections.  Alias sections are skipped.  A previously pinned
    git_revision becomes old_git_revision, so the new pin is checked to
    descend from it.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        conf = config[section]
        channel = read_search_path(conf)

        if hasattr(channel, 'alias_of'):
            assert not hasattr(channel, 'git_repo')
            continue

        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' not in conf:
            # Plain git channel: pin whatever the ref currently points to.
            git_fetch(v, channel)
            conf['release_name'] = git_revision_name(v, channel)
        else:
            pin_channel(v, channel)
            tarball = channel.table['nixexprs.tar.xz']
            conf['release_name'] = channel.release_name
            conf['tarball_url'] = tarball.absolute_url
            conf['tarball_sha256'] = tarball.digest
        conf['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
535
536
def fetch_channel(
        v: Verification,
        section: str,
        conf: configparser.SectionProxy) -> str:
    """Return a store path for the pinned tarball of config *section*.

    Channel-based pins download the recorded tarball_url and verify it
    against tarball_sha256.  Git-only pins get a tarball generated from the
    pinned revision.

    Raises Exception if the section has not been pinned yet.
    """
    if not ('git_repo' in conf and 'release_name' in conf):
        raise Exception(
            'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
            section)

    if 'channel_url' not in conf:
        channel = read_search_path(conf)
        ensure_git_rev_available(v, channel)
        return git_get_tarball(v, channel)

    expected = Digest16(conf['tarball_sha256'])
    return fetch_with_nix_prefetch_url(v, conf['tarball_url'], expected)
554
555
def update(args: argparse.Namespace) -> None:
    """Install every pinned channel from args.channels_file into the profile.

    Builds one nix expression per channel, with aliases sharing their
    target's expression.  All expressions are passed to a single
    ``nix-env --install``, or the command is just printed when
    args.dry_run is set.

    Raises Exception on a channel name defined more than once.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    configs = [read_config(filename) for filename in args.channels_file]

    def ensure_unique(section: str) -> None:
        if section in exprs:
            raise Exception('Duplicate channel "%s"' % section)

    # First pass: real channels.
    for config in configs:
        for section in config.sections():
            conf = config[section]
            if 'alias_of' in conf:
                assert 'git_repo' not in conf
                continue
            tarball = fetch_channel(v, section, conf)
            ensure_unique(section)
            # '%%s' survives this formatting as '%s'; the channel name is
            # substituted into it when the command line is built.
            exprs[section] = (
                'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
                (conf['release_name'], tarball))

    # Second pass: aliases reuse the expression of the section they name.
    for config in configs:
        for section in config.sections():
            conf = config[section]
            if 'alias_of' in conf:
                ensure_unique(section)
                exprs[section] = exprs[str(conf['alias_of'])]

    profile = ('/nix/var/nix/profiles/per-user/%s/channels' %
               getpass.getuser())
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression']
    command.extend(exprs[name] % name for name in sorted(exprs))

    if args.dry_run:
        print(' '.join(shlex.quote(arg) for arg in command))
        return
    v.status('Installing channels with nix-env')
    completed = subprocess.run(command)
    v.result(completed.returncode == 0)
596
597
def main() -> None:
    """Command-line entry point: ``pinch pin ...`` or ``pinch update ...``."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # Each subcommand stores its handler function in args.func.
    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == '__main__':
    main()