]> git.scottworley.com Git - pinch/blob - pinch.py
732a91f645c036a4692e6ea55083cad4427373d0
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
28 import xdg
29
30
# A sha256 digest rendered in base 16 (hex), as produced by hashlib's
# hexdigest() in digest_string/digest_file below.
Digest16 = NewType('Digest16', str)
# A sha256 digest in Nix's base-32 rendering, as printed by nix-prefetch-url
# and produced by `nix to-base32`.
Digest32 = NewType('Digest32', str)
33
34
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's release page.

    parse_channel fills in url, digest and size; fetch_resources later adds
    absolute_url and file.
    """
    absolute_url: str  # url resolved against the channel's forwarded URL
    digest: Digest16  # sha256 (hex) from the table's digest column
    file: str  # local path printed by `nix-prefetch-url --print-path`
    size: int  # size column from the table (presumably bytes -- confirm)
    url: str  # href as it appears in the table (may be relative)
41
42
class Channel(types.SimpleNamespace):
    """Mutable state for one channel section while pinning or updating.

    Constructed from a config section's key/value pairs; attributes absent
    from the config are filled in as the pin/update steps run.
    """
    alias_of: str  # name of the section this one is an alias of
    channel_html: bytes  # raw body of the channel page, set by fetch
    channel_url: str  # URL of the channel page, from the config
    forwarded_url: str  # final URL of the channel page after redirects
    git_ref: str  # branch fetched into the cache as refs/heads/<git_ref>
    git_repo: str  # URL of the git repository
    git_revision: str  # commit being pinned / verified
    old_git_revision: str  # previous pin; must be an ancestor of git_revision
    release_name: str  # from the channel page title, or git_revision_name
    table: Dict[str, ChannelTableEntry]  # channel page file table, by name
54
55
class VerificationError(Exception):
    """Raised by Verification.result when a check fails."""
    pass
58
59
class Verification:
    """Progress reporter that prints check results to stderr.

    Each check is one line: one or more status() fragments followed by a
    right-aligned, colored OK / FAIL written by result(). A failed result
    raises VerificationError, aborting the whole operation.
    """

    def __init__(self) -> None:
        # Characters printed so far on the current stderr line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print a status fragment, followed by a space, on the current line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # Counts code points, not terminal cells, so wide characters may
        # misalign the result column.
        self.line_length += 1 + len(s)

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in ANSI SGR color code c, resetting the color afterwards."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with OK (r true) or FAIL (r false).

        Raises:
            VerificationError: if r is false.
        """
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        # get_terminal_size() can report 0 columns on some terminals / Python
        # versions; fall back to 80 instead of dividing by zero below.
        cols = shutil.get_terminal_size().columns or 80
        # Pad so the message ends at the right edge; the modulo keeps the
        # padding non-negative when the status text is longer than one row.
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print status s and finish the line with result r."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Finish the current line with OK."""
        self.result(True)
89
90
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at a and b by content.

    Every regular file, plus every symlink that os.walk lists among a
    directory's subdirectories, found under either root is compared with
    filecmp.cmpfiles (shallow=False). Paths under a top-level .git/ are
    ignored.

    Returns:
        The (match, mismatch, errors) lists of relative paths produced by
        filecmp.cmpfiles; paths present on only one side end up in errors.

    Raises:
        OSError: if any directory cannot be walked.
    """

    def reraise(err: OSError) -> None:
        # os.walk ignores walk errors by default; make them fatal instead.
        raise err

    def rel_join(head: str, tail: str) -> str:
        if head == '.':
            return tail
        return os.path.join(head, tail)

    def listing(root: str) -> List[str]:
        found: List[str] = []
        for here, subdirs, names in os.walk(root, onerror=reraise):
            prefix = os.path.relpath(here, start=root)
            for name in names:
                found.append(rel_join(prefix, name))
            # os.walk reports symlinks to directories among the subdirs
            # (without descending into them); record those as files too.
            found.extend(
                rel_join(prefix, sub)
                for sub in subdirs
                if os.path.islink(rel_join(here, sub)))
        return found

    def without_git(root: str) -> set:
        return {p for p in listing(root) if not p.startswith('.git/')}

    candidates = without_git(a) | without_git(b)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
117
118
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page, recording its body and post-redirect URL.

    Sets channel.channel_html and channel.forwarded_url, then verifies an
    HTTP 200 status and that channel.channel_url was redirected elsewhere.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed once
    # the body has been read.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)
    # A non-redirected channel URL is treated as a failure (presumably the
    # redirect target identifies the specific release).
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
126
127
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git commit and file table from the channel page.

    channel.channel_html is parsed as XML. Sets channel.release_name,
    channel.git_revision and channel.table. Checks that:
    - the <title> and first <h1> agree on the release name;
    - the first <tt> element (the commit) is labelled "Git commit ".
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first(tag: str) -> xml.dom.minidom.Element:
        return doc.getElementsByTagName(tag)[0]

    def third_word(tag: str) -> str:
        # The release name is the third whitespace-separated word of the
        # element's leading text.
        return first(tag).firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    from_title = third_word('title')
    from_h1 = third_word('h1')
    v.status(from_title)
    v.result(from_title == from_h1)
    channel.release_name = from_title

    v.status('Extracting git commit:')
    commit_tt = first('tt')
    channel.git_revision = commit_tt.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    # The text node right before the commit must be the literal label.
    label = commit_tt.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # Skip the header row; each remaining row is (link, size, digest).
    for row in doc.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        anchor = cells[0].firstChild
        file_name = anchor.firstChild.nodeValue
        href = anchor.getAttribute('href')
        size_value = int(cells[1].firstChild.nodeValue)
        sha = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[file_name] = ChannelTableEntry(
            url=href, digest=sha, size=size_value)
    v.ok()
159
160
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded sha256 of the byte string s."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
163
164
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded sha256 of a file's contents.

    The file is read in 4096-byte chunks so large files are not loaded
    into memory at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
172
173
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 sha256 digest to hex via `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    return Digest16(completed.stdout.decode().strip())
180
181
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to Nix base-32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    return Digest32(completed.stdout.decode().strip())
188
189
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url with nix-prefetch-url and return the path it prints.

    Two checks must both match the expected hex sha256 digest:
    - the digest reported by nix-prefetch-url;
    - a fresh local hash of the downloaded file.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    # --print-path output is exactly two lines: the (base-32) digest, then
    # the path, followed by a final newline.
    output_lines = completed.stdout.decode().split('\n')
    prefetch_digest, path, trailing = output_lines
    assert trailing == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path
206
207
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    For each resource:
    - resolve its table url against channel.forwarded_url;
    - fetch it with digest verification;
    - record absolute_url and file on the table entry.
    Then check that the downloaded git-revision file matches the commit
    shown on the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly rather than leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
219
220
def git_cachedir(git_repo: str) -> str:
    """Return the local bare-repository cache directory for git_repo.

    The directory lives under xdg.XDG_CACHE_HOME/pinch/git and is named by
    the sha256 (hex) of the repository URL.
    """
    repo_key = digest_string(git_repo.encode())
    base = os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git')
    return os.path.join(base, repo_key)
226
227
def tarball_cache_file(channel: Channel) -> str:
    """Return the file that records the generated tarball path for channel.

    It lives under xdg.XDG_CACHE_HOME/pinch/git-tarball. Its name combines
    the repo-URL digest, the git revision and the release name.
    """
    name = '{}-{}-{}'.format(
        digest_string(channel.git_repo.encode()),
        channel.git_revision,
        channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', name)
236
237
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check the ancestry of channel.git_revision in the cached repo.

    The revision must be an ancestor of channel.git_ref. When a previous pin
    exists (channel.old_git_revision), that previous revision must be an
    ancestor of the new one, i.e. the new revision descends from it.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(older: str, newer: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 when A is an ancestor
        # of B.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  older,
                                  newer])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The check is "previous rev is an ancestor of rev", so say that the
        # new rev *descends* from the previous one.
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
262
263
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local bare cache.

    If channel.git_revision is already set, verifies that the fetch brought
    in that commit. Otherwise sets it to the fetched head of the ref.
    Finally runs verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Ask git to resolve the ref instead of reading refs/heads/<ref> as a
        # file. After `git gc` / `git pack-refs` the ref lives in
        # packed-refs and the loose file no longer exists.
        v.status('Resolving fetched ref')
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'rev-parse',
                                  '--verify',
                                  'refs/heads/%s' % channel.git_ref],
                                 capture_output=True)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
304
305
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision is present in the local git cache.

    If the cache already has the commit, only the ancestry checks run.
    Otherwise (no cache, or commit missing) the ref is fetched via git_fetch.
    """
    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        git_fetch(v, channel)
        return

    v.status('Checking if we already have this rev:')
    code = subprocess.run(
        ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision]
    ).returncode
    # Exit code 0 means present and 1 means absent; any other exit code
    # fails the check.
    answer = {0: 'yes', 1: 'no'}.get(code)
    if answer is not None:
        v.status(answer)
    v.result(answer is not None)
    if code == 0:
        verify_git_ancestry(v, channel)
        return
    git_fetch(v, channel)
321
322
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check the unpacked channel tarball against the git checkout.

    Requires all of the following:
    - at least one file matches;
    - no file differs;
    - the files that cannot be compared are exactly the expected extras
      listed below.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Files expected to be incomparable. .git-revision and .version-suffix
    # ship in the tarball (see check_channel_metadata); the rest are
    # presumably channel-only extras as well -- confirm.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
354
355
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz into the directory dest."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
363
364
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Extract the tree of channel.git_revision into the directory dest.

    Pipes `git archive` from the cached repo into `tar x`.
    """
    v.status('Checking out corresponding git revision')
    archive_cmd = ['git',
                   '-C',
                   git_cachedir(channel.git_repo),
                   'archive',
                   channel.git_revision]
    untar_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    producer = subprocess.Popen(archive_cmd, stdout=subprocess.PIPE)
    consumer = subprocess.Popen(untar_cmd, stdin=producer.stdout)
    # Drop our copy of the pipe's read end so git gets SIGPIPE if tar exits
    # early.
    if producer.stdout:
        producer.stdout.close()
    consumer.wait()
    producer.wait()
    v.result(producer.returncode == 0 and consumer.returncode == 0)
380
381
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Return a Nix store path of a .tar.xz of channel.git_revision.

    Entries in the tarball are prefixed with "<release_name>/". The store
    path is remembered in tarball_cache_file(channel) and reused while that
    path still exists. Otherwise the tarball is regenerated with
    `git archive | xz` and added with `nix-store --add`.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as cache:
            cached_tarball = cache.read(9999)
        # The cached store path may have been removed since it was recorded.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # Binary mode: xz writes compressed bytes straight to this file.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe's read end. Otherwise, if xz exits
            # early, git blocks on a full pipe and git.wait() never returns.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], capture_output=True)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball
417
418
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    <release_name>/.git-revision must equal the commit from the channel
    page. <release_name>/.version-suffix must be a suffix of the release
    name.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
441
442
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack the channel tarball and a git checkout side by side and compare.

    Both trees live in temporary directories that are deleted before the
    final OK is printed.
    """
    with tempfile.TemporaryDirectory() as tarball_dir:
        with tempfile.TemporaryDirectory() as checkout_dir:
            extract_tarball(v, channel, tarball_dir)
            check_channel_metadata(v, channel, tarball_dir)
            git_checkout(v, channel, checkout_dir)
            compare_tarball_and_git(v, channel, tarball_dir, checkout_dir)
            v.status('Removing temporary directories')
    v.ok()
456
457
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every verification step needed to pin a channel_url channel.

    The steps are:
    - fetch and parse the channel page;
    - download its resources;
    - make the git revision available locally;
    - compare the channel tarball against the git checkout.
    """
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
464
465
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Return a release name for a git-only channel.

    The name is "<repo basename>-<committer timestamp>-<short hash>".
    %ct is the committer date as a Unix timestamp. The hash is abbreviated
    to at least 11 characters.
    """
    v.status('Getting commit date')
    # `log` (not the nonexistent `lo`), limited to the single pinned commit.
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
480
481
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file at filename.

    The file is closed once parsed. filename is passed as the source name
    so parse errors mention it.
    """
    config = configparser.ConfigParser()
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
486
487
def pin(args: argparse.Namespace) -> None:
    """Pin channels in args.channels_file and rewrite the file in place.

    Only sections named in args.channels are processed (all when empty);
    alias sections are skipped.

    channel_url sections are verified via pin_channel and record:
    - release_name;
    - tarball_url;
    - tarball_sha256.

    Other sections are fetched from git and record release_name.

    Every pinned section records git_revision.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if wanted and section not in wanted:
            continue

        conf = config[section]
        channel = Channel(**dict(conf.items()))

        if hasattr(channel, 'alias_of'):
            assert not hasattr(channel, 'git_repo')
            continue

        # Keep the previous pin so the new revision can be checked to
        # descend from it.
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in conf:
            pin_channel(v, channel)
            tarball = channel.table['nixexprs.tar.xz']
            conf['release_name'] = channel.release_name
            conf['tarball_url'] = tarball.absolute_url
            conf['tarball_sha256'] = tarball.digest
        else:
            git_fetch(v, channel)
            conf['release_name'] = git_revision_name(v, channel)
        conf['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
517
518
def fetch_channel(
        v: Verification,
        section: str,
        conf: configparser.SectionProxy) -> str:
    """Return a Nix store path of the pinned tarball for one config section.

    Raises:
        Exception: if the section has not been pinned yet.
    """
    pinned = 'git_repo' in conf and 'release_name' in conf
    if not pinned:
        raise Exception(
            'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
            section)

    if 'channel_url' in conf:
        # Re-fetch the published tarball, verified against the pinned digest.
        tarball_digest = Digest16(conf['tarball_sha256'])
        return fetch_with_nix_prefetch_url(
            v, conf['tarball_url'], tarball_digest)

    # git-only channel: build a tarball from the pinned revision.
    channel = Channel(**dict(conf.items()))
    ensure_git_rev_available(v, channel)
    return git_get_tarball(v, channel)
536
537
def update(args: argparse.Namespace) -> None:
    """Install the pinned channels from args.channels_file with nix-env.

    Steps:
    - fetch (or build) every non-alias section across all files;
    - give each alias section its target's expression;
    - install everything into the user's channels profile.

    With args.dry_run, the nix-env command is printed instead of run.

    Raises:
        Exception: on a duplicate channel name, or an alias whose target is
            not a channel defined earlier.
    """
    v = Verification()
    # Channel name -> (release name, store path of the channel tarball).
    # Kept as data and formatted once at the end: formatting twice (the old
    # "%%s" trick) broke on names or paths containing '%'.
    sources: Dict[str, Tuple[str, str]] = {}
    configs = [read_config(filename) for filename in args.channels_file]
    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                assert 'git_repo' not in config[section]
                continue
            tarball = fetch_channel(v, section, config[section])
            if section in sources:
                raise Exception('Duplicate channel "%s"' % section)
            sources[section] = (config[section]['release_name'], tarball)

    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                if section in sources:
                    raise Exception('Duplicate channel "%s"' % section)
                target = str(config[section]['alias_of'])
                if target not in sources:
                    raise Exception(
                        'Channel "%s" is an alias of unknown channel "%s"' %
                        (section, target))
                sources[section] = sources[target]

    def expression(name: str) -> str:
        release_name, tarball = sources[name]
        return (
            'f: f { name = "%s"; channelName = "%s"; src = builtins.storePath "%s"; }' %
            (release_name, name, tarball))

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [expression(name) for name in sorted(sources)]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
578
579
def main() -> None:
    """Parse the command line and dispatch to the pin or update subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    # update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
593
594
# Only run the CLI when executed as a script, so importing the module (e.g.
# for tests or an entry-point wrapper) has no side effects.
if __name__ == '__main__':
    main()