]> git.scottworley.com Git - pinch/blob - pinch.py
34c63f1d4b418380142baf28334507485f3f372f
[pinch] / pinch.py
1 from abc import ABC, abstractmethod
2 import argparse
3 import configparser
4 import filecmp
5 import functools
6 import getpass
7 import hashlib
8 import operator
9 import os
10 import os.path
11 import shlex
12 import shutil
13 import subprocess
14 import sys
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Dict,
23 Iterable,
24 List,
25 NewType,
26 Tuple,
27 )
28
29 # Use xdg module when it's less painful to have as a dependency
30
31
class XDG(types.SimpleNamespace):
    # Holder for XDG base-directory values; only XDG_CACHE_HOME is read
    # by this script (see git_cachedir / tarball_cache_file).
    XDG_CACHE_HOME: str
34
35
# Honor $XDG_CACHE_HOME if set, falling back to the conventional
# default of ~/.cache.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))
40
41
class VerificationError(Exception):
    """Raised by Verification.result when a check fails (after FAIL is printed)."""
    pass
44
45
class Verification:
    """Prints progress messages to stderr and finishes each line with a
    right-aligned, colored OK/FAIL verdict; FAIL aborts by raising
    VerificationError."""

    def __init__(self) -> None:
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print one progress fragment, tracking its printed width so
        result() can pad the verdict to the terminal's right edge."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        # Wrap s in the ANSI SGR escape for color code c.
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current status line with OK (green) or FAIL (red);
        a failure raises VerificationError."""
        message, color = ('OK ', 92) if r else ('FAIL', 91)
        cols = shutil.get_terminal_size().columns or 80
        # Pad so the verdict lands flush against the right margin.
        pad = (cols - (self.line_length + len(message))) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Shorthand: print a status fragment and its verdict together."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Finish the current status line with OK."""
        self.result(True)
75
76
# Hex-encoded sha256 digest (as produced by hashlib's hexdigest).
Digest16 = NewType('Digest16', str)
# Nix base-32-encoded sha256 digest (as produced by `nix to-base32`).
Digest32 = NewType('Digest32', str)
79
80
class ChannelTableEntry(types.SimpleNamespace):
    # One artifact row from the channel page's table (see parse_channel).
    absolute_url: str  # url resolved against the forwarded channel URL
    digest: Digest16  # sha256 listed in the table
    file: str  # local nix-store path once fetched
    size: int  # byte count listed in the table
    url: str  # href exactly as it appears in the table
87
88
class SearchPath(types.SimpleNamespace, ABC):
    """One named entry of the channels config that can be pinned."""
    release_name: str

    @abstractmethod
    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin this search path, writing the resolved values into conf."""
        pass
95
96
class AliasSearchPath(SearchPath):
    # Name of the sibling config section whose pin/update this mirrors.
    alias_of: str

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """An alias has nothing of its own to pin; the assert guards
        against a section that sets both alias_of and git_repo."""
        assert not hasattr(self, 'git_repo')
102
103
class Channel(SearchPath):
    """A channel pinned either via its channel page (when channel_url is
    configured) or directly from a git repository."""
    channel_html: bytes  # raw channel page, set by fetch()
    channel_url: str
    forwarded_url: str  # final URL after the channel page redirect
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str  # previous pin, kept for ancestry checking
    table: Dict[str, ChannelTableEntry]

    def pin(self, v: Verification, conf: configparser.SectionProxy) -> None:
        """Pin to the channel's current contents, writing release_name
        plus either tarball details or the git revision into conf."""
        if hasattr(self, 'git_revision'):
            # Re-pinning: stash the previous revision so
            # verify_git_ancestry can confirm fast-forward history.
            self.old_git_revision = self.git_revision
            del self.git_revision

        if 'channel_url' in conf:
            pin_channel(v, self)
            conf['release_name'] = self.release_name
            conf['tarball_url'] = self.table['nixexprs.tar.xz'].absolute_url
            conf['tarball_sha256'] = self.table['nixexprs.tar.xz'].digest
        else:
            git_fetch(v, self)
            conf['release_name'] = git_revision_name(v, self)
            conf['git_revision'] = self.git_revision

    def fetch(self, v: Verification, section: str,
              conf: configparser.SectionProxy) -> str:
        """Fetch the pinned channel into the nix store and return its
        store path.  Raises if the section was never pinned."""
        if 'git_repo' not in conf or 'release_name' not in conf:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)

        if 'channel_url' in conf:
            return fetch_with_nix_prefetch_url(
                v, conf['tarball_url'], Digest16(
                    conf['tarball_sha256']))

        ensure_git_rev_available(v, self)
        return git_get_tarball(v, self)
143
144
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Recursively compare the directory trees rooted at a and b,
    ignoring anything under .git/.

    Returns (match, mismatch, errors) as produced by filecmp.cmpfiles;
    errors holds names that could not be compared in both trees.
    """

    def raise_error(error: OSError) -> None:
        raise error

    def relative_join(prefix: str, name: str) -> str:
        return name if prefix == '.' else os.path.join(prefix, name)

    def walk(top: str) -> Iterable[str]:
        # Collect regular files plus symlinks-to-directories (which
        # os.walk reports in `dirs`, not `files`).
        found: List[str] = []
        for path, dirs, files in os.walk(top, onerror=raise_error):
            rel = os.path.relpath(path, start=top)
            for f in files:
                found.append(relative_join(rel, f))
            for d in dirs:
                if os.path.islink(relative_join(path, d)):
                    found.append(relative_join(rel, d))
        return found

    names = {
        f
        for tree in (a, b)
        for f in walk(tree)
        if not f.startswith('.git/')
    }
    return filecmp.cmpfiles(a, b, names, shallow=False)
171
172
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page, recording its HTML and the URL it
    ultimately redirected to."""
    v.status('Fetching channel')
    response = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel.channel_html = response.read()
    channel.forwarded_url = response.geturl()
    v.result(response.status == 200)  # type: ignore # (for old mypy)
    v.check('Got forwarded', channel.forwarded_url != channel.channel_url)
180
181
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract the release name, git revision, and artifact table from
    the channel page HTML stored by fetch()."""
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    # The release name appears in both <title> and <h1>; cross-check
    # them (third whitespace-separated word of each).
    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    # The commit hash is the first <tt> element; the text node just
    # before it should read 'Git commit '.
    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    # Each table row after the header is: name/link, size, sha256.
    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
213
214
def digest_string(s: bytes) -> Digest16:
    """Return the sha256 of s as a lowercase hex string."""
    hasher = hashlib.sha256(s)
    return Digest16(hasher.hexdigest())
217
218
def digest_file(filename: str) -> Digest16:
    """Return the sha256 of the file's contents as a lowercase hex
    string, reading in 4 KiB chunks."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            chunk = f.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
226
227
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a nix base-32 sha256 digest to hex via `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())
234
235
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to nix base-32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())
242
243
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url into the nix store via nix-prefetch-url, checking both
    the tool's reported digest and our own sha256 of the file.

    Returns the nix store path of the fetched file.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Output format is "<digest>\n<store path>\n".
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
260
261
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the git-revision and nixexprs tarball artifacts listed
    in the channel table, then check that the git commit shown on the
    page matches the contents of the git-revision artifact."""
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Fix: use a context manager so the file handle is closed promptly
    # instead of being leaked to the garbage collector.
    with open(channel.table['git-revision'].file) as f:
        v.result(f.read(999) == channel.git_revision)
273
274
def git_cachedir(git_repo: str) -> str:
    """Return the per-repository bare-clone cache directory, keyed by a
    hash of the repo URL under XDG_CACHE_HOME."""
    repo_hash = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_hash)
280
281
def tarball_cache_file(channel: Channel) -> str:
    """Return the cache file that records the nix-store path of a
    previously generated tarball for this repo/revision/release."""
    name = '%s-%s-%s' % (
        digest_string(channel.git_repo.encode()),
        channel.git_revision,
        channel.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', name)
290
291
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check that git_revision is an ancestor of git_ref and, when a
    previous pin exists, that the old revision is an ancestor of the
    new one (i.e. history moved forward, not sideways)."""
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # `git merge-base --is-ancestor` signals the answer via its
        # exit status.
        process = subprocess.run(
            ['git', '-C', cachedir, 'merge-base', '--is-ancestor',
             ancestor, descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
316
317
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local bare
    cache repository; when no revision is pinned yet, adopt the ref's
    new tip as channel.git_revision.  Finishes with an ancestry check.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Adopt whatever the ref now points at.  Fix: read via a context
        # manager so the file handle is closed promptly (the original
        # leaked it).
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
358
359
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision exists in the local cache,
    fetching from the remote only when it is missing."""
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        have_rev = process.returncode == 0
        if have_rev:
            v.status('yes')
        elif process.returncode == 1:
            v.status('no')
        # Exit codes other than 0/1 mean git itself failed.
        v.result(process.returncode == 0 or process.returncode == 1)
        if have_rev:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
375
376
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Diff the extracted channel tarball against the git checkout,
    tolerating only the known channel-build-generated files."""
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Files the channel build adds that never exist in git.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [ee for ee in expected_errors if ee in errors]
    for ee in benign_errors:
        errors.remove(ee)
    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
408
409
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the downloaded nixexprs tarball into dest."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
417
418
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Export channel.git_revision from the cached repo into dest by
    piping `git archive` into `tar x`."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Close our copy of the pipe end so git sees a broken pipe if tar
    # exits early, rather than blocking forever.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
434
435
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Build (or reuse) a nix-store tarball for channel.git_revision.

    Returns the tarball's store path.  The path is memoized in a cache
    file keyed on repo, revision, and release name.
    """
    cache_file = tarball_cache_file(channel)
    if os.path.exists(cache_file):
        # Fix: context manager closes the cache file promptly instead of
        # leaking the handle.
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        # The store path may have been garbage-collected; re-verify it.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Fix: write the memo via a context manager (the original left
        # the handle open with the write possibly unflushed).
        with open(cache_file, 'w') as f:
            f.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
471
472
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Verify the extracted tarball's .git-revision matches the channel
    page's commit and that .version-suffix is a suffix of the release
    name."""
    v.status('Verifying git commit in channel tarball')
    # Fix: open files via context managers so the handles are closed
    # promptly instead of being leaked.
    with open(os.path.join(
            channel_contents,
            channel.release_name,
            '.git-revision')) as f:
        v.result(f.read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(
            channel_contents,
            channel.release_name,
            '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
495
496
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Cross-check the channel tarball against git: extract both into
    temporary directories, verify the tarball metadata, and compare the
    two file trees."""
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()
510
511
def pin_channel(v: Verification, channel: Channel) -> None:
    """Pin a channel-URL-based channel: fetch and parse the channel
    page, download its artifacts, make sure the git revision is in the
    local cache, and verify the tarball and git checkout agree."""
    fetch(v, channel)
    parse_channel(v, channel)
    fetch_resources(v, channel)
    ensure_git_rev_available(v, channel)
    check_channel_contents(v, channel)
518
519
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release-style name for the pinned revision:
    '<repo basename>-<commit timestamp>-<abbreviated hash>'."""
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              channel.git_revision],
                             stdout=subprocess.PIPE)
    # Empty output means git printed nothing useful even if it exited 0.
    v.result(process.returncode == 0 and process.stdout != b'')
    commit_stamp = process.stdout.decode().strip()
    repo_name = os.path.basename(channel.git_repo)
    return '%s-%s' % (repo_name, commit_stamp)
535
536
def read_search_path(conf: configparser.SectionProxy) -> SearchPath:
    """Instantiate the SearchPath subclass matching a config section:
    an AliasSearchPath when alias_of is present, else a Channel."""
    spec = dict(conf.items())
    if 'alias_of' in spec:
        return AliasSearchPath(**spec)
    return Channel(**spec)
541
542
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse filename as an INI config.

    read_file (rather than read) is used so a missing file raises
    instead of being silently ignored.
    """
    config = configparser.ConfigParser()
    # Fix: context manager closes the file promptly instead of leaking
    # the handle.
    with open(filename) as f:
        config.read_file(f, filename)
    return config
547
548
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several config files into one mapping; a
    section name appearing in more than one file is an error."""
    merged: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged:
                raise Exception('Duplicate channel "%s"' % name)
            merged[name] = parsed[name]
    return merged
559
560
def pin(args: argparse.Namespace) -> None:
    """Handle `pinch pin`: pin the requested channels (all sections when
    none are named) and rewrite the channels file in place."""
    v = Verification()
    config = read_config(args.channels_file)
    selected = args.channels
    for section in config.sections():
        # An empty selection means "pin everything".
        if selected and section not in selected:
            continue
        search_path = read_search_path(config[section])
        search_path.pin(v, config[section])

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
574
575
def update(args: argparse.Namespace) -> None:
    """Handle `pinch update`: fetch every pinned channel's tarball and
    install them all into the per-user nix channels profile."""
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        if 'alias_of' in config[section]:
            # Aliases are resolved in a second pass, after their targets
            # have been fetched.
            assert 'git_repo' not in config[section]
            continue
        sp = read_search_path(config[section])
        tarball = sp.fetch(v, section, config[section])
        # The %%s placeholder is filled with the section name (as
        # channelName) when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
610
611
def main() -> None:
    """Parse the command line and dispatch to the pin or update handler."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pin)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=update)

    arguments = parser.parse_args()
    arguments.func(arguments)
625
626
# Script entry point.
if __name__ == '__main__':
    main()