]> git.scottworley.com Git - pinch/blob - pinch.py
Introduce SearchPath type
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
# Use xdg module when it's less painful to have as a dependency


class XDG(types.SimpleNamespace):
    """Minimal stand-in for the ``xdg`` module.

    Only the one base directory this file reads is modelled.
    """
    # Directory under which this program keeps its caches.
    XDG_CACHE_HOME: str


# Take the cache directory from the environment when it is set there and
# fall back to ``~/.cache`` in the current user's home directory otherwise.
xdg = XDG(XDG_CACHE_HOME=os.environ.get(
    'XDG_CACHE_HOME', os.path.expanduser('~/.cache')))
39
40
# Separate static types for the two textual digest encodings used in this
# file (base16 and base32), so a type checker can tell them apart.  Both are
# ordinary ``str`` values at runtime.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
43
44
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table parsed out of a channel description page."""
    absolute_url: str  # ``url`` joined onto the channel's forwarded URL
    digest: Digest16  # digest listed for the file in the table
    file: str  # local path of the file once it has been fetched
    size: int  # size listed for the file in the table
    url: str  # link target exactly as it appears in the table
51
52
class SearchPath(types.SimpleNamespace):
    """Attribute namespace that ``Channel`` extends; holds the release name."""
    release_name: str
55
56
class Channel(SearchPath):
    """Everything known about one channel.

    Instances are built from the key/value pairs of a config file section and
    are then filled in progressively by the fetch/parse/verify functions.
    """
    alias_of: str  # name of another section this one is an alias for
    channel_html: bytes  # raw body downloaded from ``channel_url``
    channel_url: str  # URL of the channel description page
    forwarded_url: str  # URL the request for ``channel_url`` ended up at
    git_ref: str  # ref that is fetched from ``git_repo``
    git_repo: str  # git repository to fetch from
    git_revision: str  # commit the channel is (being) pinned to
    old_git_revision: str  # commit recorded in the config before re-pinning
    table: Dict[str, ChannelTableEntry]  # file table keyed by file name
67
68
class VerificationError(Exception):
    """Raised by ``Verification.result`` when a check does not pass."""
    pass
71
72
class Verification:
    """Progress reporter that prints checks to stderr and aborts on failure.

    An output line consists of one or more ``status`` fragments followed by a
    coloured verdict which ``result`` pads towards the right-hand edge of the
    terminal.
    """

    def __init__(self) -> None:
        # Characters already written on the current output line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print ``s`` followed by a space, staying on the current line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap ``s`` in the terminal escape sequence for colour code ``c``."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with a verdict and start a new one.

        Raises:
            VerificationError: if ``r`` is false.
        """
        verdicts = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = verdicts[r]
        cols = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        # The modulo keeps the padding sensible when the line has already
        # wrapped past the terminal width.
        pad = (cols - used) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print ``s`` and immediately report ``r`` as its verdict."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the step announced by the last ``status``."""
        self.result(True)
102
103
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the directory trees ``a`` and ``b`` file by file.

    The set of paths compared is the union of the regular files and the
    symlinked directories found under either tree, minus anything whose
    relative path starts with ``.git/``.

    Returns:
        The ``(match, mismatch, errors)`` lists produced by
        ``filecmp.cmpfiles`` with ``shallow=False``.

    Raises:
        OSError: if walking either tree fails.
    """

    def reraise(error: OSError) -> None:
        # os.walk swallows errors unless its onerror hook raises them.
        raise error

    def under(parent: str, child: str) -> str:
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def tracked_paths(root: str) -> Iterable[str]:
        for here, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(here, start=root)
            for filename in filenames:
                yield under(relative, filename)
            # os.walk lists symlinks to directories among the directories;
            # they are compared as entries in their own right.
            for subdir in subdirs:
                if os.path.islink(under(here, subdir)):
                    yield under(relative, subdir)

    def outside_dot_git(paths: Iterable[str]) -> Iterable[str]:
        return (p for p in paths if not p.startswith('.git/'))

    paths_in_a = set(outside_dot_git(tracked_paths(a)))
    paths_in_b = set(outside_dot_git(tracked_paths(b)))
    return filecmp.cmpfiles(a, b, paths_in_a | paths_in_b, shallow=False)
130
131
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel description page.

    Stores the response body in ``channel.channel_html`` and the URL the
    request ended up at in ``channel.forwarded_url``.

    Raises:
        VerificationError: if the response status is not 200, or if the
            final URL is the same as ``channel.channel_url``.
    """
    v.status('Fetching channel')
    # The context manager closes the HTTP response even when reading fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)  # type: ignore # (for old mypy)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
139
140
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git revision and file table from the page.

    Parses ``channel.channel_html`` as XML and fills in
    ``channel.release_name``, ``channel.git_revision`` and ``channel.table``.

    Raises:
        VerificationError: if the title and heading disagree on the release
            name, or the git revision is not preceded by the expected label.
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # The release name is the third whitespace-separated word of the
        # text in the first element with the given tag.
        element = document.getElementsByTagName(tag)[0]
        return element.firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_element = document.getElementsByTagName('tt')[0]
    channel.git_revision = commit_element.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_element.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first row is skipped.
    for row in document.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        name = cells[0].firstChild.firstChild.nodeValue
        url = cells[0].firstChild.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        entry = ChannelTableEntry(url=url, digest=digest, size=size)
        channel.table[name] = entry
    v.ok()
172
173
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 digest of ``s`` as a hexadecimal string."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
176
177
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 digest of a file's contents as a hex string.

    The file is read in 4096-byte blocks so it never has to fit in memory.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if block == b'':
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
185
186
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 by running ``nix to-base16``.

    Raises:
        VerificationError: if the ``nix`` command exits non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    converted = process.stdout.decode().strip()
    return Digest16(converted)
193
194
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 by running ``nix to-base32``.

    Raises:
        VerificationError: if the ``nix`` command exits non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    converted = process.stdout.decode().strip()
    return Digest32(converted)
201
202
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch ``url`` with ``nix-prefetch-url`` and double-check its digest.

    Both the digest printed by ``nix-prefetch-url`` and a digest computed
    here from the resulting file must equal ``digest``.

    Returns:
        The path printed by ``nix-prefetch-url --print-path``.

    Raises:
        VerificationError: if the command fails or either digest differs.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)

    # Expected output: a digest line, a path line, and a final newline.
    output_lines = process.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''

    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)

    v.status("Verifying file digest")
    computed_digest = digest_file(path)
    v.result(computed_digest == digest)
    return path  # type: ignore # (for old mypy)
219
220
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Fetch the ``git-revision`` and ``nixexprs.tar.xz`` table entries.

    For each of the two entries, ``absolute_url`` is resolved against
    ``channel.forwarded_url`` and ``file`` is set to the fetched path.  The
    fetched ``git-revision`` file must then contain exactly
    ``channel.git_revision``.

    Raises:
        VerificationError: if a fetch fails verification or the revisions
            do not agree.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Use a context manager so the file handle is closed promptly.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
232
233
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory used for the bare clone of ``git_repo``.

    The directory name is the digest of the repository string, placed under
    ``pinch/git`` in the XDG cache directory.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
239
240
def tarball_cache_file(channel: Channel) -> str:
    """Return the path of the cache entry for this channel's git tarball.

    The entry name combines the digest of the repository string, the git
    revision and the release name; entries live under ``pinch/git-tarball``
    in the XDG cache directory.
    """
    repo_key = digest_string(channel.git_repo.encode())
    entry_name = '%s-%s-%s' % (
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
249
250
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check that the pinned revision sits where it should in history.

    Two checks are run with ``git merge-base --is-ancestor`` in the cached
    repository:

    * ``channel.git_revision`` must be an ancestor of ``channel.git_ref``;
    * if the channel has an ``old_git_revision``, that revision must be an
      ancestor of ``channel.git_revision``.

    Raises:
        VerificationError: if either check fails.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # The command exits 0 when ``ancestor`` is an ancestor of
        # ``descendant``.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The message names the relation actually tested: the previous
        # revision must be an ancestor of the new one.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
275
276
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch ``channel.git_ref`` from ``channel.git_repo`` into the cache.

    The cache repository is created first if it does not exist.  When the
    channel already has a ``git_revision``, the fetch must have retrieved
    it; otherwise ``git_revision`` is set from the fetched ref.  Ancestry is
    verified afterwards in both cases.

    Raises:
        VerificationError: if any git command fails or ancestry
            verification fails.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        # Close the ref file promptly rather than leaving it to the garbage
        # collector.
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
317
318
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure ``channel.git_revision`` is present in the cached repo.

    If the cache already contains the revision, only its ancestry is
    verified; in every other case a fetch is performed.

    Raises:
        VerificationError: if ``git cat-file -e`` exits with anything other
            than 0 or 1, or if a later verification step fails.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        lookup = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        answers = {0: 'yes', 1: 'no'}
        exit_code = lookup.returncode
        if exit_code in answers:
            v.status(answers[exit_code])
        # Any other exit code is treated as a failure.
        v.result(exit_code in answers)
        if exit_code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
334
335
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the unpacked channel tarball matches the git checkout.

    The release directory inside ``channel_contents`` is compared with
    ``git_contents``.  At least one file must match, none may differ, and
    the set of incomparable files must be exactly the five expected names.

    Raises:
        VerificationError: if any of those conditions does not hold.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    # Split the incomparable files into the ones we expect and the rest.
    benign_errors = [name for name in expected_errors if name in errors]
    unexpected_errors = [
        name for name in errors if name not in benign_errors]

    v.check(
        '%d unexpected incomparable files' %
        len(unexpected_errors),
        len(unexpected_errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))
367
368
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the fetched ``nixexprs.tar.xz`` into the directory ``dest``."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
376
377
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the tree of ``channel.git_revision`` into ``dest``.

    Runs ``git archive`` in the cached repository and pipes its output into
    ``tar x``.

    Raises:
        VerificationError: if either process exits non-zero.
    """
    v.status('Checking out corresponding git revision')
    archive_command = ['git',
                       '-C',
                       git_cachedir(channel.git_repo),
                       'archive',
                       channel.git_revision]
    unpack_command = ['tar', 'x', '-C', dest, '-f', '-']

    git = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    tar = subprocess.Popen(unpack_command, stdin=git.stdout)
    # tar holds its own copy of the pipe; drop ours.
    if git.stdout:
        git.stdout.close()
    tar_status = tar.wait()
    git_status = git.wait()
    v.result(git_status == 0 and tar_status == 0)
393
394
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Return a Nix store path holding a tarball of the pinned revision.

    A cache file records the store path from an earlier run; it is reused
    when the path it names still exists.  Otherwise ``git archive`` output is
    compressed with ``xz``, added to the store with ``nix-store --add``, and
    the resulting path is written to the cache file.

    Raises:
        VerificationError: if ``git``, ``xz`` or ``nix-store`` exits
            non-zero.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        with open(cache_file) as cached:
            cached_tarball = cached.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes compressed bytes straight to the file descriptor, so the
        # file is opened in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop our copy of the pipe so that git is not left blocked on a
            # pipe nobody reads if xz exits early (git_checkout does the
            # same).
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Close the cache file explicitly so the path is flushed to disk.
        with open(cache_file, 'w') as cache:
            cache.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
430
431
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    ``.git-revision`` in the release directory must equal
    ``channel.git_revision``, and the contents of ``.version-suffix`` must
    be a suffix of ``channel.release_name``.

    Raises:
        VerificationError: if either check fails.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Context managers make sure both metadata files are closed again.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
454
455
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack the tarball and the git tree, then compare the two.

    Both are unpacked into temporary directories that are removed again
    before this function returns.

    Raises:
        VerificationError: if any of the checks fails.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            # The verdict is printed once both directories are gone.
            v.status('Removing temporary directories')
    v.ok()
469
470
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every verification stage for a channel, in order.

    Raises:
        VerificationError: as soon as one stage fails.
    """
    stages = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for stage in stages:
        stage(v, channel)
477
478
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name from the repository and the pinned commit.

    Returns:
        ``<repo basename>-<commit timestamp>-<abbreviated hash>``, where the
        last two parts come from ``git log`` in the cached repository.

    Raises:
        VerificationError: if ``git log`` fails or prints nothing.
    """
    v.status('Getting commit date')
    command = ['git',
               '-C',
               git_cachedir(channel.git_repo),
               'log',
               '-n1',
               '--format=%ct-%h',
               '--abbrev=11',
               '--no-show-signature',
               channel.git_revision]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    commit_stamp = process.stdout.decode().strip()
    return '%s-%s' % (repo_name, commit_stamp)
494
495
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style file ``filename`` and return the parser.

    Raises:
        OSError: if the file cannot be opened.
        configparser.Error: if the file is not valid.
    """
    config = configparser.ConfigParser()
    # The context manager closes the file as soon as parsing is done.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
500
501
def pin(args: argparse.Namespace) -> None:
    """Pin channels from ``args.channels_file`` and write the file back.

    When ``args.channels`` is non-empty only the sections it names are
    processed.  Alias sections are skipped.  Sections with a ``channel_url``
    go through the full channel verification; all others are pinned from git
    alone.

    Raises:
        VerificationError: if any verification step fails; the config file
            is not rewritten in that case.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        selected = not args.channels or section in args.channels
        if not selected:
            continue

        conf = config[section]
        channel = Channel(**dict(conf.items()))

        if hasattr(channel, 'alias_of'):
            assert not hasattr(channel, 'git_repo')
            continue

        # Set the current pin aside as the previous revision so a fresh one
        # can be determined.
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in conf:
            pin_channel(v, channel)
            tarball = channel.table['nixexprs.tar.xz']
            conf['release_name'] = channel.release_name
            conf['tarball_url'] = tarball.absolute_url
            conf['tarball_sha256'] = tarball.digest
        else:
            git_fetch(v, channel)
            conf['release_name'] = git_revision_name(v, channel)
        conf['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
531
532
def fetch_channel(
        v: Verification,
        section: str,
        conf: configparser.SectionProxy) -> str:
    """Obtain the tarball for a pinned channel.

    Sections with a ``channel_url`` are fetched from their recorded tarball
    URL and digest; all others get a tarball generated from git.

    Returns:
        The path of the tarball.

    Raises:
        Exception: if the section lacks ``git_repo`` or ``release_name``.
        VerificationError: if fetching or verification fails.
    """
    is_pinned = 'git_repo' in conf and 'release_name' in conf
    if not is_pinned:
        raise Exception(
            'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
            section)

    if 'channel_url' not in conf:
        channel = Channel(**dict(conf.items()))
        ensure_git_rev_available(v, channel)
        return git_get_tarball(v, channel)

    expected_digest = Digest16(conf['tarball_sha256'])
    return fetch_with_nix_prefetch_url(
        v, conf['tarball_url'], expected_digest)
550
551
def update(args: argparse.Namespace) -> None:
    """Install every pinned channel into the user's channels profile.

    All files in ``args.channels_file`` are read.  Each non-alias section
    contributes one expression for ``nix-env``; alias sections reuse the
    expression of the section they name.  With ``args.dry_run`` the command
    is printed instead of executed.

    Raises:
        Exception: if a channel name occurs more than once or a channel is
            not pinned.
        KeyError: if an alias names a channel that does not exist.
        VerificationError: if fetching fails or ``nix-env`` exits non-zero.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    configs = [read_config(filename) for filename in args.channels_file]
    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                assert 'git_repo' not in config[section]
                continue
            # Reject duplicates before doing any fetching for them.
            if section in exprs:
                raise Exception('Duplicate channel "%s"' % section)
            tarball = fetch_channel(v, section, config[section])
            # "%%s" leaves a placeholder for the channel name, which is
            # filled in when the command is assembled so that aliases get
            # their own name.
            exprs[section] = (
                'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
                (config[section]['release_name'], tarball))

    # Aliases are resolved in a second pass, once every real channel has an
    # expression.
    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                if section in exprs:
                    raise Exception('Duplicate channel "%s"' % section)
                exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
592
593
def main() -> None:
    """Parse the command line and dispatch to ``pin`` or ``update``."""
    parser = argparse.ArgumentParser(prog='pinch')
    modes = parser.add_subparsers(dest='mode', required=True)

    # pinch pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = modes.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    # pinch update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = modes.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    parsed = parser.parse_args()
    parsed.func(parsed)
607
608
# Run the command-line interface only when executed as a script, not when
# imported as a module.
if __name__ == '__main__':
    main()