]> git.scottworley.com Git - pinch/blob - pinch.py
26982260cc62c5f485175c6ee77e84868909a0a5
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
28 # Use xdg module when it's less painful to have as a dependency
29
30
class XDG(types.SimpleNamespace):
    """Typed namespace holding the XDG directory settings this script reads."""
    # Root directory under which this script keeps its caches.
    XDG_CACHE_HOME: str
33
34
# Resolved once at import time: $XDG_CACHE_HOME when it is set in the
# environment, otherwise the expansion of ~/.cache.
xdg = XDG(
    XDG_CACHE_HOME=os.environ.get(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))
39
40
class VerificationError(Exception):
    """Raised when a verification step reports a failed result."""
43
44
class Verification:
    """Progress reporter for verification steps.

    Step descriptions are written to stderr on one line; the outcome is then
    printed as a colored OK/FAIL marker padded toward the terminal's right
    edge.  A failed outcome raises ``VerificationError``.
    """

    def __init__(self) -> None:
        # Number of characters already written on the current status line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Write ``s`` plus a trailing space to stderr without a newline."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # One extra column for the separating space.
        self.line_length = self.line_length + len(s) + 1  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap ``s`` in the ANSI escape sequence for color code ``c``."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with OK or FAIL; raise on failure.

        Raises:
            VerificationError: if ``r`` is false.
        """
        outcomes = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = outcomes[r]
        cols = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        # The modulo keeps the padding non-negative when the status text has
        # already wrapped past the terminal width.
        pad = (cols - used) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Report description ``s`` and its outcome ``r`` in one call."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Finish the current line with a successful result."""
        self.result(True)
74
75
# Distinct static types for digest strings, so a base16 digest and a base32
# digest cannot be mixed up without the type checker noticing.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
78
79
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table parsed from a channel page."""
    # `url` joined onto the channel's forwarded URL.
    absolute_url: str
    digest: Digest16
    # Local path returned by nix-prefetch-url for this entry.
    file: str
    size: int
    # The link target exactly as it appears in the table.
    url: str
86
87
class SearchPath(types.SimpleNamespace):
    """Base namespace for a configured search-path entry."""
    release_name: str
90
91
class Channel(SearchPath):
    """A channel built from one config section.

    Attributes are populated from the section's keys and filled in further
    while the channel is pinned; any of them may be absent on an instance.
    """
    alias_of: str
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str
    table: Dict[str, ChannelTableEntry]

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin this channel and record the pinned values in ``conf``.

        Alias channels are left untouched.  A previously pinned revision is
        moved to ``old_git_revision`` before a new one is determined.
        """
        if hasattr(self, 'alias_of'):
            assert not hasattr(self, 'git_repo')
            return

        if hasattr(self, 'git_revision'):
            # Keep the previous pin so ancestry against it can be verified.
            self.old_git_revision = self.git_revision
            del self.git_revision

        if 'channel_url' not in conf:
            # Plain git channel: the revision comes from fetching the ref.
            git_fetch(v, self)
            conf['release_name'] = git_revision_name(v, self)
            conf['git_revision'] = self.git_revision
            return

        pin_channel(v, self)
        conf['release_name'] = self.release_name
        conf['tarball_url'] = self.table['nixexprs.tar.xz'].absolute_url
        conf['tarball_sha256'] = self.table['nixexprs.tar.xz'].digest

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Obtain the pinned tarball and return the path produced for it.

        Raises:
            Exception: if the section lacks ``git_repo`` or ``release_name``.
        """
        is_pinned = 'git_repo' in conf and 'release_name' in conf
        if not is_pinned:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        if 'channel_url' in conf:
            expected_digest = Digest16(conf['tarball_sha256'])
            return fetch_with_nix_prefetch_url(
                v, conf['tarball_url'], expected_digest)

        ensure_git_rev_available(v, self)
        return git_get_tarball(v, self)
136
137
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the contents of directory trees ``a`` and ``b``.

    Every file (and every symlink to a directory) found under either tree,
    except paths beginning with ``.git/``, is compared by content.

    Returns:
        The ``(match, mismatch, errors)`` lists produced by
        ``filecmp.cmpfiles``, holding paths relative to the tree roots.

    Raises:
        OSError: if walking either tree fails.
    """

    def reraise(error: OSError) -> None:
        # os.walk swallows errors unless the onerror callback raises them.
        raise error

    def rel_join(parent: str, child: str) -> str:
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def list_tree(root: str) -> Iterable[str]:
        found: List[str] = []
        for current, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(current, start=root)
            for filename in filenames:
                found.append(rel_join(relative, filename))
            # Symlinks to directories are listed among the directories; record
            # them as entries of their own.
            for subdir in subdirs:
                if os.path.islink(rel_join(current, subdir)):
                    found.append(rel_join(relative, subdir))
        return found

    def outside_dot_git(root: str) -> Iterable[str]:
        return {name for name in list_tree(root)
                if not name.startswith('.git/')}

    candidates = outside_dot_git(a) | outside_dot_git(b)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
164
165
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page and record its body and final URL.

    Sets ``channel.channel_html`` and ``channel.forwarded_url``, then verifies
    that the response status was 200 and that the request was redirected
    (the final URL differs from ``channel.channel_url``).

    Raises:
        VerificationError: if either check fails.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed even
    # when reading fails or a verification step raises.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        channel.channel_html = response.read()
        channel.forwarded_url = response.geturl()
        status = response.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
173
174
def parse_channel(v: Verification, channel: Channel) -> None:
    """Parse ``channel.channel_html`` and fill in the channel's metadata.

    Sets ``channel.release_name``, ``channel.git_revision`` and
    ``channel.table`` (one ``ChannelTableEntry`` per table row after the
    first, keyed by file name).

    Raises:
        VerificationError: if the title and h1 disagree on the release name
            or the git commit is not preceded by the expected label.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # The release name is the third whitespace-separated word of the
        # first element with this tag.
        return doc.getElementsByTagName(tag)[0].firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first row is skipped.
    for row in doc.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
206
207
def digest_string(s: bytes) -> Digest16:
    """Return the sha256 digest of ``s`` as a hexadecimal string."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
210
211
def digest_file(filename: str) -> Digest16:
    """Return the sha256 digest of a file's contents as a hexadecimal string.

    The file is read in 4096-byte blocks so it never has to fit in memory.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if block == b'':
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
219
220
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 by running ``nix to-base16``.

    Raises:
        VerificationError: if the command exits with a non-zero status.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    output = completed.stdout.decode()
    return Digest16(output.strip())
227
228
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 by running ``nix to-base32``.

    Raises:
        VerificationError: if the command exits with a non-zero status.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    output = completed.stdout.decode()
    return Digest32(output.strip())
235
236
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch ``url`` with ``nix-prefetch-url`` and verify it against ``digest``.

    The digest printed by the tool and the digest of the resulting file are
    both checked against ``digest``.

    Returns:
        The path printed by ``nix-prefetch-url --print-path``.

    Raises:
        VerificationError: if the command fails or either digest differs.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)

    # Expected output is exactly two newline-terminated lines: digest, path.
    output_lines = completed.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''

    # Convert first: the conversion reports its own status line before the
    # comparison's line is printed.
    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)

    v.status("Verifying file digest")
    actual_digest = digest_file(path)
    v.result(actual_digest == digest)
    return path  # type: ignore # (for old mypy)
253
254
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the channel's ``git-revision`` and ``nixexprs.tar.xz`` files.

    For each of the two table entries this sets ``absolute_url`` (the entry's
    url joined onto ``channel.forwarded_url``) and ``file`` (the fetched
    path).  The fetched ``git-revision`` file must then equal
    ``channel.git_revision``.

    Raises:
        VerificationError: if a fetch fails verification or the revisions
            differ.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)

    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
266
267
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory used for ``git_repo``.

    The directory name is the digest of the repository string, under
    ``pinch/git`` in the XDG cache home.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
273
274
def tarball_cache_file(channel: Channel) -> str:
    """Return the path of the cache file for this channel's git tarball.

    The file name combines the digest of the repository string, the git
    revision and the release name, under ``pinch/git-tarball`` in the XDG
    cache home.
    """
    repo_key = digest_string(channel.git_repo.encode())
    basename = '%s-%s-%s' % (
        repo_key, channel.git_revision, channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', basename)
283
284
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Verify the channel's revision against its ref and its previous pin.

    Runs ``git merge-base --is-ancestor`` in the repository's cache directory
    to check that ``git_revision`` is an ancestor of ``git_ref`` and, when
    the channel has an ``old_git_revision``, that the old revision is an
    ancestor of ``git_revision``.

    Raises:
        VerificationError: if either ancestry check fails.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        completed = subprocess.run(
            ['git', '-C', cachedir, 'merge-base', '--is-ancestor',
             ancestor, descendant])
        return completed.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if not hasattr(channel, 'old_git_revision'):
        return

    # NOTE: the command below tests that the *previous* revision is an
    # ancestor of the current one; the message states it the other way round.
    v.status(
        'Verifying rev is an ancestor of previous rev %s' %
        channel.old_git_revision)
    v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
309
310
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch ``channel.git_ref`` from ``channel.git_repo`` into the cache.

    The bare cache repository is created on first use.  If the channel
    already has a ``git_revision`` the fetch must have retrieved it;
    otherwise ``git_revision`` is set from the fetched ref.  Ancestry is
    verified afterwards.

    Raises:
        VerificationError: if any git command or check fails.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        # Close the ref file deterministically rather than leaking the handle.
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
351
352
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure ``channel.git_revision`` exists in the cache repository.

    If the cache directory exists and already contains the revision, only
    its ancestry is verified; otherwise the ref is fetched.

    Raises:
        VerificationError: if ``git cat-file -e`` exits with a status other
            than 0 or 1, or if a later verification step fails.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        completed = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        code = completed.returncode
        # 0 means the object exists and 1 that it does not; anything else is
        # treated as a failure.
        understood = code in (0, 1)
        if understood:
            v.status('yes' if code == 0 else 'no')
        v.result(understood)
        if code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
368
369
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify that the extracted channel tarball matches the git checkout.

    At least one file must match, none may differ, and the incomparable
    files must be exactly the fixed list of names expected below.

    Raises:
        VerificationError: if any of those conditions does not hold.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    # Move the expected names out of `errors`, leaving only surprises there.
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
401
402
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the channel's fetched ``nixexprs.tar.xz`` into ``dest``."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
410
411
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Extract the tree of ``channel.git_revision`` into ``dest``.

    Pipes ``git archive`` from the cache repository into ``tar x``.

    Raises:
        VerificationError: if either process exits with a non-zero status.
    """
    v.status('Checking out corresponding git revision')
    archive_command = ['git',
                       '-C',
                       git_cachedir(channel.git_repo),
                       'archive',
                       channel.git_revision]
    extract_command = ['tar', 'x', '-C', dest, '-f', '-']

    producer = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    consumer = subprocess.Popen(extract_command, stdin=producer.stdout)
    # Drop our copy of the pipe's read end so tar holds the only one.
    if producer.stdout:
        producer.stdout.close()
    consumer.wait()
    producer.wait()
    both_succeeded = producer.returncode == 0 and consumer.returncode == 0
    v.result(both_succeeded)
427
428
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Return a tarball of ``channel.git_revision``, building it if needed.

    A cache file remembers the path produced by an earlier run; it is reused
    only while that path still exists.  Otherwise ``git archive`` is piped
    through ``xz`` into a temporary file, the file is added with
    ``nix-store --add``, and the resulting path is written to the cache file.

    Returns:
        The path printed by ``nix-store --add`` (or the cached one).

    Raises:
        VerificationError: if git, xz or nix-store exits non-zero.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        # Close the cache file deterministically instead of leaking it.
        with open(cache_file) as cached:
            cached_tarball = cached.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop our copy of the pipe's read end (as git_checkout does) so
            # xz holds the only one and the descriptor is not leaked.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        # Must happen while the temporary directory still exists.
        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # The context manager guarantees the path is flushed and closed
        # before returning.
        with open(cache_file, 'w') as cache_out:
            cache_out.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
464
465
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Verify the metadata files inside the extracted channel tarball.

    ``.git-revision`` must equal ``channel.git_revision`` and the contents
    of ``.version-suffix`` must be a suffix of ``channel.release_name``.
    Both files are looked up under ``channel_contents/<release_name>``.

    Raises:
        VerificationError: if either check fails.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Files are opened via context managers so the handles are closed
    # promptly rather than leaked.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
488
489
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Verify the channel tarball against the corresponding git revision.

    The tarball and the git revision are unpacked into two temporary
    directories, the tarball's metadata is checked, and the two trees are
    compared.

    Raises:
        VerificationError: if any check fails.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            v.status('Removing temporary directories')
    # Reported only after both directories have actually been removed.
    v.ok()
503
504
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run the full pinning pipeline for a channel that has a channel URL.

    The steps run in order: fetch the page, parse it, fetch the listed
    resources, make the git revision available, and verify the contents.
    """
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
511
512
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name for ``channel.git_revision``.

    Returns:
        ``<repo basename>-<git log output>``, where the git output uses the
        format ``%ct-%h`` with ``--abbrev=11``.

    Raises:
        VerificationError: if ``git log`` fails or prints nothing.
    """
    v.status('Getting commit date')
    command = ['git',
               '-C',
               git_cachedir(channel.git_repo),
               'log',
               '-n1',
               '--format=%ct-%h',
               '--abbrev=11',
               '--no-show-signature',
               channel.git_revision]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0 and completed.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    stamp = completed.stdout.decode().strip()
    return '%s-%s' % (repo_name, stamp)
528
529
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Build a ``Channel`` whose attributes are the section's key/value pairs."""
    attributes = {key: value for key, value in conf.items()}
    return Channel(**attributes)
532
533
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style file ``filename`` and return the parser.

    Raises:
        OSError: if the file cannot be opened.
        configparser.Error: if the file is not valid configuration.
    """
    config = configparser.ConfigParser()
    # Open via a context manager so the handle is closed as soon as parsing
    # finishes (or fails) instead of being leaked.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
538
539
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several config files and merge their sections into one mapping.

    Raises:
        Exception: if the same section name appears more than once across
            the files.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
550
551
def pin(args: argparse.Namespace) -> None:
    """Handle the ``pin`` subcommand.

    Pins every section of ``args.channels_file`` (or only those named in
    ``args.channels`` when that list is non-empty) and writes the updated
    configuration back to the same file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        selected = not args.channels or section in args.channels
        if selected:
            section_conf = config[section]
            sp = read_search_path(section_conf)
            # Pinning records its results in the section itself.
            sp.pin(v, config[section])

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
565
566
def update(args: argparse.Namespace) -> None:
    """Handle the ``update`` subcommand.

    Fetches the pinned tarball of every non-alias channel in the given
    config files, builds one Nix expression per channel (aliases reuse their
    target's expression), and installs them with ``nix-env``.  With
    ``args.dry_run`` the command is printed instead of executed.

    Raises:
        VerificationError: if fetching or the ``nix-env`` run fails.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)

    # First pass: real channels.
    for section, section_conf in config.items():
        if 'alias_of' in section_conf:
            assert 'git_repo' not in section_conf
            continue
        sp = read_search_path(section_conf)
        tarball = sp.fetch(v, section, section_conf)
        # "%%s" leaves a "%s" placeholder that is filled in with the channel
        # name when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (section_conf['release_name'], tarball))

    # Second pass: aliases share the expression template of their target.
    for section, section_conf in config.items():
        if 'alias_of' in section_conf:
            exprs[section] = exprs[str(section_conf['alias_of'])]

    profile = ('/nix/var/nix/profiles/per-user/%s/channels' %
               getpass.getuser())
    expressions = [exprs[name] % name for name in sorted(exprs)]
    command = [
        'nix-env',
        '--profile',
        profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + expressions

    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
        return

    v.status('Installing channels with nix-env')
    process = subprocess.run(command)
    v.result(process.returncode == 0)
601
602
def main() -> None:
    """Parse the command line and dispatch to ``pin`` or ``update``."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # pinch pin CHANNELS_FILE [CHANNEL ...]
    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    # pinch update [--dry-run] CHANNELS_FILE [CHANNELS_FILE ...]
    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    args = parser.parse_args()
    # Each subparser stored its handler in `func`.
    args.func(args)
616
617
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()