]> git.scottworley.com Git - pinch/blob - pinch.py
b1471007f9399db512013a95a226210ef8cdf04c
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
# Hex (base-16) SHA-256 digest string, as produced by hashlib's hexdigest().
Digest16 = NewType('Digest16', str)
# SHA-256 digest in Nix's base-32 encoding (converted with `nix to-base32`).
Digest32 = NewType('Digest32', str)
30
31
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel sets url, digest and size; fetch_resources later fills in
    absolute_url and file for the resources it downloads.
    """
    absolute_url: str  # url resolved against the channel's forwarded URL
    digest: Digest16  # hex SHA-256 listed in the table row
    file: str  # local path printed by `nix-prefetch-url --print-path`
    size: int  # size column of the table; presumably bytes — TODO confirm
    url: str  # href from the table row, possibly relative
38
39
class Channel(types.SimpleNamespace):
    """State for one section of the channels config file.

    Instances are built from a config section's key/value pairs and then
    filled in while the channel is fetched and verified. Not every attribute
    is present on every instance; code checks with hasattr() (e.g. alias_of,
    git_revision, old_git_revision).
    """
    alias_of: str  # name of another section this section aliases
    channel_html: bytes  # raw body of the channel page (set by fetch)
    channel_url: str  # URL of the channel page to fetch
    forwarded_url: str  # channel_url after redirects (set by fetch)
    git_ref: str  # branch fetched into refs/heads/<git_ref> of the cache
    git_repo: str  # git repository URL
    git_revision: str  # commit being pinned
    old_git_revision: str  # previously pinned commit, if any (set by pin)
    release_name: str  # release name recorded in the config
    table: Dict[str, ChannelTableEntry]  # channel file table keyed by name
51
52
class VerificationError(Exception):
    """Raised by Verification.result() when a check fails."""
    pass
55
56
class Verification:
    """Progress reporter that prints checks to stderr with OK/FAIL markers.

    Each check is shown as one or more status words followed by a coloured
    marker aligned to the right edge of the terminal. A failed check raises
    VerificationError.
    """

    def __init__(self) -> None:
        # Number of columns already printed on the current stderr line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print ``s`` plus a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # NOTE: len() counts code points, not display columns (wide or
        # combining characters will misalign the marker).
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap ``s`` in the ANSI SGR escape sequence for colour code ``c``."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with OK or FAIL.

        Raises:
            VerificationError: if ``r`` is false.
        """
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        columns = shutil.get_terminal_size().columns
        # Right-align the marker; the modulo keeps the padding within one
        # terminal row when the status text has already wrapped.
        padding = (columns - self.line_length - len(message)) % columns
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print ``s`` and immediately report ``r``."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current line."""
        self.result(True)
86
87
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file.

    The union of every file, and every symlink to a directory, found under
    ``a`` or ``b`` is compared by content with
    ``filecmp.cmpfiles(shallow=False)``. Entries under a top-level ``.git/``
    are excluded.

    Returns:
        The ``(match, mismatch, errors)`` lists from ``filecmp.cmpfiles``.
        Paths are relative to the tree roots.

    Raises:
        OSError: if walking either tree fails.
    """

    def reraise(err: OSError) -> None:
        # os.walk ignores directory-listing errors unless onerror is given.
        raise err

    def rel_join(parent: str, child: str) -> str:
        # relpath() returns '.' for the root itself; avoid a './' prefix.
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def tree_listing(root: str) -> Iterable[str]:
        listing: List[str] = []
        for dirpath, dirnames, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(dirpath, start=root)
            listing.extend(rel_join(relative, name) for name in filenames)
            # os.walk reports symlinked directories in dirnames without
            # descending into them; record them so they are compared too.
            listing.extend(
                rel_join(relative, name)
                for name in dirnames
                if os.path.islink(rel_join(dirpath, name)))
        return listing

    names_a, names_b = (
        {name for name in tree_listing(root) if not name.startswith('.git/')}
        for root in (a, b))
    return filecmp.cmpfiles(a, b, names_a | names_b, shallow=False)
114
115
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel's HTML page.

    Stores the body in ``channel.channel_html`` and the final (post-redirect)
    URL in ``channel.forwarded_url``. Then verifies that the response status
    was 200 and that the channel URL redirected somewhere else.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # even if reading fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)
    # fetch_resources resolves table links against forwarded_url, so a
    # redirect to a release-specific URL is required here.
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
123
124
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git commit and file table from the channel page.

    Sets ``channel.release_name`` (third word of <title>, verified against
    <h1>), ``channel.git_revision`` (text of the first <tt>, whose preceding
    text must be 'Git commit ') and ``channel.table`` (one
    ChannelTableEntry per <tr> after the header row).
    """
    v.status('Parsing channel description as XML')
    dom = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        node = dom.getElementsByTagName(tag)[0]
        return node.firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    release = third_word('title')
    heading = third_word('h1')
    v.status(release)
    v.result(release == heading)
    channel.release_name = release

    v.status('Extracting git commit:')
    tt = dom.getElementsByTagName('tt')[0]
    channel.git_revision = tt.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(tt.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # Skip the header row; each remaining row is (link, size, digest).
    for tr in dom.getElementsByTagName('tr')[1:]:
        cells = tr.childNodes
        anchor = cells[0].firstChild
        name = anchor.firstChild.nodeValue
        url = anchor.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
156
157
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 digest of ``s``."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
160
161
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded SHA-256 digest of the file at ``filename``.

    The file is read in 4096-byte chunks rather than all at once.
    """
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
169
170
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 SHA-256 digest to hex with `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
177
178
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base-32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
185
186
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download ``url`` into the Nix store and verify it against ``digest``.

    The digest is checked twice: once as reported by nix-prefetch-url
    (converted from base-32) and once by hashing the stored file directly.

    Returns:
        The store path printed by `nix-prefetch-url --print-path`.
    """
    v.status('Fetching %s' % url)
    completed = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        capture_output=True)
    v.result(completed.returncode == 0)
    # Expected output: "<base32 digest>\n<store path>\n".
    prefetch_digest, store_path, trailing = completed.stdout.decode().split('\n')
    assert trailing == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    v.result(digest_file(store_path) == digest)
    return store_path
203
204
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the git-revision and nixexprs.tar.xz files listed on the page.

    Each entry's ``absolute_url`` (resolved against the forwarded channel
    URL) and local ``file`` path are filled in. Then the commit recorded in
    the git-revision file is checked against the one shown on the page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
216
217
def git_cachedir(git_repo: str) -> str:
    """Return the path of the local bare-repo cache for ``git_repo``.

    The directory name is the SHA-256 of the repository URL, under
    ~/.cache/pinch/git/.
    """
    # TODO: Consider using pyxdg to find this path.
    repo_key = digest_string(git_repo.encode())
    return os.path.expanduser('~/.cache/pinch/git/%s' % repo_key)
224
225
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Verify that the pinned commit is reachable from the branch.

    If a previously pinned revision is known, also verify that the new
    revision descends from it. This rejects rollbacks and history rewrites.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(older: str, newer: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 iff A is an ancestor of B.
        process = subprocess.run(['git', '-C', cachedir, 'merge-base',
                                  '--is-ancestor', older, newer])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The check is old_git_revision -> git_revision, i.e. the new rev is
        # a *descendant* of the previous one (the old message had this
        # relationship backwards).
        v.status(
            'Verifying rev is a descendant of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
250
251
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch ``channel.git_ref`` into the local bare-repo cache and verify it.

    The cache repository is created on first use. If ``channel.git_revision``
    is already set, verifies that the fetched objects include it. Otherwise
    sets it from the fetched branch head. Finally checks ancestry with
    verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # NOTE(review): this reads the loose ref file directly; if git ever
        # packs refs in this cache (e.g. via gc), the file may not exist —
        # consider `git rev-parse` instead.
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        # Close the ref file promptly instead of leaking the handle.
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
292
293
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure ``channel.git_revision`` is present in the local git cache.

    If the cache already has the commit, only the ancestry checks are run.
    Otherwise (or when the cache does not exist yet) git_fetch is called.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        code = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e',
             channel.git_revision]).returncode
        # cat-file -e: 0 = present, 1 = absent; anything else is an error.
        answer = {0: 'yes', 1: 'no'}.get(code)
        if answer is not None:
            v.status(answer)
        v.result(code in (0, 1))
        if code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
309
310
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    At least one file must match and none may differ. The only files allowed
    to be incomparable (present on one side only, or unreadable) are a fixed
    set of channel-generated files, and all of them must be present.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
342
343
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz into ``dest``."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
351
352
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Write the tree of ``channel.git_revision`` into ``dest``.

    Pipes `git archive` from the cache repository into `tar x`.
    """
    v.status('Checking out corresponding git revision')
    archiver = subprocess.Popen(
        ['git', '-C', git_cachedir(channel.git_repo), 'archive',
         channel.git_revision],
        stdout=subprocess.PIPE)
    extractor = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archiver.stdout)
    # Drop the parent's copy of the pipe so git sees SIGPIPE if tar exits.
    archiver.stdout.close()
    extractor.wait()
    archiver.wait()
    v.result(archiver.returncode == 0 and extractor.returncode == 0)
367
368
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Build an xz tarball of ``channel.git_revision`` and add it to the store.

    Entries in the tarball are prefixed with ``<release_name>/``.

    Returns:
        The Nix store path printed by `nix-store --add`.
    """
    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes binary data straight to this file descriptor.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close the parent's copy of the pipe so git receives SIGPIPE if
            # xz exits early (same pattern as git_checkout).
            git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], capture_output=True)
        v.result(process.returncode == 0)
        return process.stdout.decode().strip()
394
395
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the extracted channel tarball.

    .git-revision must equal the commit on the channel page, and
    .version-suffix must be a suffix of the release name.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Files are read inside `with` so their handles are closed.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
418
419
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Extract the channel tarball and a git checkout and compare them.

    Both are placed in temporary directories that are removed afterwards.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            v.status('Removing temporary directories')
    # Reported only after both temporary directories have been cleaned up.
    v.ok()
433
434
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every fetch and verification step for a channel-page channel."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
441
442
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name for a git-only channel.

    Returns:
        ``<basename of git_repo>-<committer unix time>-<11-char hash>``.
    """
    v.status('Getting commit date')
    # Was 'lo', which is not a git subcommand (only works with a user alias).
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    # stdout is bytes: compare against b'' (comparing with '' was always True).
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
457
458
def pin(args: argparse.Namespace) -> None:
    """Pin channels from ``args.channels_file`` and write the results back.

    Only sections named in ``args.channels`` are processed when that list is
    non-empty. Alias sections are skipped. Sections with ``channel_url`` are
    verified through the channel page; all others are pinned straight from
    git.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Close the file after parsing; it is reopened for writing below.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        channel = Channel(**dict(config[section].items()))

        if hasattr(channel, 'alias_of'):
            assert not hasattr(channel, 'git_repo')
            continue

        # Keep the previous pin for the ancestry check, and clear
        # git_revision so it is re-derived from the current channel/branch.
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in config[section]:
            pin_channel(v, channel)
            tarball = channel.table['nixexprs.tar.xz']
            config[section]['release_name'] = channel.release_name
            config[section]['tarball_url'] = tarball.absolute_url
            config[section]['tarball_sha256'] = tarball.digest
        else:
            git_fetch(v, channel)
            config[section]['release_name'] = git_revision_name(v, channel)
        config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
489
490
def update(args: argparse.Namespace) -> None:
    """Install the channels pinned in ``args.channels_file`` with nix-env.

    Each non-alias section yields a tarball in the Nix store: downloaded and
    verified for channel-page channels, or generated from git otherwise.
    Alias sections reuse their target's expression. With ``args.dry_run``
    the nix-env command is printed instead of run.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Close the config file once parsed instead of leaking the handle.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    exprs: Dict[str, str] = {}
    for section in config.sections():

        if 'alias_of' in config[section]:
            assert 'git_repo' not in config[section]
            continue

        if 'channel_url' in config[section]:
            tarball = fetch_with_nix_prefetch_url(
                v, config[section]['tarball_url'], Digest16(
                    config[section]['tarball_sha256']))
        else:
            channel = Channel(**dict(config[section].items()))
            ensure_git_rev_available(v, channel)
            tarball = git_get_tarball(v, channel)

        # '%%s' survives this formatting as '%s'; it is filled with the
        # section name when the command line is built below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    for section in config.sections():
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
535
536
def main() -> None:
    """Command-line entry point: dispatch to the `pin` or `update` mode."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str)
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
550
551
# Only run the CLI when executed as a script, not when imported (e.g. by a
# console-script entry point that calls main() itself).
if __name__ == '__main__':
    main()