]> git.scottworley.com Git - pinch/blob - pinch.py
Use xdg package to find XDG cache dir
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 NewType,
25 Tuple,
26 )
27
28 import xdg
29
30
# SHA-256 digests in the two encodings this tool juggles: base16 (hex,
# as produced by hashlib) and base32 (as printed by nix tools).
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
33
34
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the channel page's file table.

    Populated incrementally: url/digest/size are scraped from the page
    (parse_channel); absolute_url and file are filled in while fetching
    (fetch_resources).
    """
    absolute_url: str  # url resolved against the forwarded channel URL
    digest: Digest16  # expected sha256 (base16) of the resource
    file: str  # nix store path of the fetched resource
    size: int  # byte size as listed in the table
    url: str  # href as scraped from the table (may be relative)
41
42
class Channel(types.SimpleNamespace):
    """State for one [section] of the channels config file.

    Constructed from the raw config keys; the remaining attributes are
    filled in as pinning proceeds (fetch, parse_channel, git_fetch).
    """
    alias_of: str  # name of another section this one mirrors
    channel_html: bytes  # raw bytes of the channel page
    channel_url: str  # URL of the channel page (channel-backed sections)
    forwarded_url: str  # final URL after the channel page's redirect
    git_ref: str  # git ref to fetch (e.g. a branch name)
    git_repo: str  # git repository URL
    git_revision: str  # commit hash being pinned
    old_git_revision: str  # previously pinned commit, for ancestry check
    release_name: str  # release name scraped from the channel page
    table: Dict[str, ChannelTableEntry]  # file table from the page
54
55
class VerificationError(Exception):
    """Raised by Verification.result when a check fails."""
    pass


class Verification:
    """Prints progress messages to stderr, tracking the width of the
    current line so OK/FAIL markers can be right-aligned."""

    def __init__(self) -> None:
        self.line_length = 0

    def status(self, s: str) -> None:
        """Emit a status fragment, staying on the current line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print a right-aligned OK/FAIL marker and reset the line.

        Raises VerificationError when *r* is false.
        """
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        cols = shutil.get_terminal_size().columns
        pad = (cols - (self.line_length + len(message))) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Convenience: status message followed by its result."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Record an unconditional success for the current line."""
        self.result(True)
89
90
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Deep-compare the directory trees *a* and *b*, ignoring .git/.

    Returns the (match, mismatch, errors) triple from filecmp.cmpfiles
    computed over the union of both trees' files.  Symlinks to
    directories are compared as files rather than descended into.
    """

    def raise_error(error: OSError) -> None:
        # os.walk swallows errors by default; we want them loud.
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        found: List[str] = []
        for path, dirs, filenames in os.walk(d, onerror=raise_error):
            rel = os.path.relpath(path, start=d)
            for f in filenames:
                found.append(join(rel, f))
            for entry in dirs:
                # walk does not descend into symlinked dirs; record them
                # so they still participate in the comparison.
                if os.path.islink(join(path, entry)):
                    found.append(join(rel, entry))
        return found

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = set()
    for tree in (a, b):
        files |= set(exclude_dot_git(recursive_files(tree)))
    return filecmp.cmpfiles(a, b, files, shallow=False)
117
118
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page.

    Sets channel.channel_html (raw bytes) and channel.forwarded_url
    (the URL after redirects).  Verification fails if the request does
    not return 200 or was not redirected — channel URLs are expected to
    forward to a release-specific page.
    """
    v.status('Fetching channel')
    # Use a context manager so the HTTP connection is closed promptly
    # instead of being leaked until garbage collection.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        v.result(request.status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
126
127
def parse_channel(v: Verification, channel: Channel) -> None:
    """Scrape release name, git revision, and the file table from the
    channel page HTML.

    Relies on the exact DOM layout of the page: the release name
    appears in both <title> and <h1>, the commit hash in the first
    <tt>, and each file as a table row (link, size, digest).
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    # The release name is the third word of the <title> and of the
    # <h1>; the two must agree.
    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    # Guard against grabbing some unrelated <tt>: the node must be
    # introduced by the literal text "Git commit ".
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    # Each data row is: cell 0 = <a href=url>name</a>, cell 1 = size,
    # cell 2 = digest.  Row 0 is the header and is skipped.
    v.status('Parsing table')
    channel.table = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
159
160
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 digest of *s* as lowercase hex (base16)."""
    hasher = hashlib.sha256(s)
    return Digest16(hasher.hexdigest())
163
164
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 digest of the file's contents, as base16.

    Reads in 4 KiB chunks so large files never live in memory whole.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if not block:
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
172
173
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 via `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    process = subprocess.run(command, capture_output=True)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())
180
181
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    process = subprocess.run(command, capture_output=True)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())
188
189
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch *url* into the nix store and return its store path.

    The content is verified against *digest* twice: once via the hash
    nix-prefetch-url reports, and once by re-hashing the stored file
    ourselves.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        capture_output=True)
    v.result(process.returncode == 0)
    # nix-prefetch-url prints "<digest>\n<store-path>\n".
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path
206
207
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Fetch the channel's git-revision and nixexprs.tar.xz files.

    Resolves each table entry's URL against the forwarded channel URL,
    fetches it into the nix store (recording the store path in
    entry.file), then cross-checks the fetched git-revision file
    against the revision scraped from the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Use a with-statement so the file handle is closed rather than
    # leaked to the garbage collector.
    with open(channel.table['git-revision'].file) as f:
        v.result(f.read(999) == channel.git_revision)
219
220
def git_cachedir(git_repo: str) -> str:
    """Per-repository cache directory under $XDG_CACHE_HOME/pinch/git,
    keyed by a hash of the repository URL."""
    repo_hash = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_hash)
227
228
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check the pinned revision's place in history.

    The revision must be an ancestor of channel.git_ref, and — when a
    previous pin is known — must descend from it, so the pin never
    moves backwards.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        process = subprocess.run(
            ['git', '-C', cachedir, 'merge-base', '--is-ancestor',
             ancestor, descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
253
254
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local cache.

    Initializes the bare cache repository on first use.  If
    channel.git_revision is already set, verify the fetch retrieved it;
    otherwise adopt the fetched ref's head as the pinned revision.
    Ends by verifying ancestry (rev on ref; new rev descends from old).
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # No pin yet: adopt the fetched head.  Read via a context
        # manager so the ref file handle is not leaked.
        with open(os.path.join(cachedir, 'refs', 'heads',
                               channel.git_ref)) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
295
296
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision is present in the git cache.

    Checks the cache first and only falls back to a network fetch when
    the revision is missing (or the cache does not exist yet).
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        # Any exit code other than 0 ("have it") or 1 ("don't") means
        # git itself failed.
        v.result(process.returncode in (0, 1))
        if process.returncode == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
312
313
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify the extracted channel tarball matches the git checkout.

    Some files are generated at channel-build time and legitimately
    absent from (or different in) the plain git tree; those are counted
    separately as expected incomparable files.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(
        os.path.join(channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [ee for ee in expected_errors if ee in errors]
    for ee in benign_errors:
        errors.remove(ee)
    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
345
346
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the channel's fetched nixexprs tarball into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
354
355
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Materialize channel.git_revision into *dest* via git archive | tar."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            channel.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    if git.stdout:
        # Close our copy of the pipe so git receives SIGPIPE if tar
        # exits early (standard subprocess pipeline idiom).
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
371
372
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Build an xz-compressed tarball of channel.git_revision, add it to
    the nix store, and return the resulting store path."""
    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            if git.stdout:
                # Close our copy of the pipe so git receives SIGPIPE if
                # xz exits early (same idiom as git_checkout; previously
                # missing here).
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], capture_output=True)
        v.result(process.returncode == 0)
        return process.stdout.decode().strip()
398
399
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Verify the extracted tarball's build-time metadata files.

    .git-revision must equal the revision scraped from the channel
    page, and .version-suffix must be a suffix of the release name.
    """
    v.status('Verifying git commit in channel tarball')
    # with-statements so the file handles are closed, not leaked.
    with open(os.path.join(channel_contents,
                           channel.release_name,
                           '.git-revision')) as f:
        v.result(f.read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(channel_contents,
                           channel.release_name,
                           '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
422
423
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Extract the channel tarball and a git checkout side by side and
    verify they agree: metadata files first, then file-by-file contents."""
    with tempfile.TemporaryDirectory() as tarball_dir, \
            tempfile.TemporaryDirectory() as checkout_dir:

        extract_tarball(v, channel, tarball_dir)
        check_channel_metadata(v, channel, tarball_dir)

        git_checkout(v, channel, checkout_dir)

        compare_tarball_and_git(v, channel, tarball_dir, checkout_dir)

        v.status('Removing temporary directories')
        v.ok()
437
438
def pin_channel(v: Verification, channel: Channel) -> None:
    """Full pinning pipeline for one channel-backed section: fetch and
    parse the channel page, download its resources, ensure the
    corresponding git revision is cached locally, and cross-check the
    tarball contents against git."""
    fetch(v, channel)
    parse_channel(v, channel)
    fetch_resources(v, channel)
    ensure_git_rev_available(v, channel)
    check_channel_contents(v, channel)
445
446
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name of the form "<repo-basename>-<commit
    timestamp>-<abbreviated hash>" for the pinned revision."""
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',  # was 'lo', which is not a git command
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
461
462
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse *filename* as an INI channels file.

    Uses read_file (not ConfigParser.read) so a missing file raises
    instead of being silently ignored, and closes the handle instead of
    leaking it.
    """
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config
467
468
def pin(args: argparse.Namespace) -> None:
    """`pinch pin`: verify each configured channel and write the pinned
    release/tarball/revision details back into the channels file."""
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        # A channel list on the command line restricts which sections
        # get (re)pinned.
        if args.channels and section not in args.channels:
            continue

        channel = Channel(**dict(config[section].items()))

        # Alias sections carry no pin state of their own.
        if hasattr(channel, 'alias_of'):
            assert not hasattr(channel, 'git_repo')
            continue

        # Remember the previous pin so ancestry can be verified — the
        # channel must only move forward.
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in config[section]:
            # Channel-backed section: pin from the channel page.
            pin_channel(v, channel)
            config[section]['release_name'] = channel.release_name
            config[section]['tarball_url'] = channel.table['nixexprs.tar.xz'].absolute_url
            config[section]['tarball_sha256'] = channel.table['nixexprs.tar.xz'].digest
        else:
            # Plain git section: pin the fetched ref's revision directly.
            git_fetch(v, channel)
            config[section]['release_name'] = git_revision_name(v, channel)
            config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
498
499
def update(args: argparse.Namespace) -> None:
    """`pinch update`: install the pinned channels into the user's
    nix-env channels profile (or just print the command with --dry-run).

    First pass materializes a nix expression per non-alias section;
    second pass resolves aliases to the expressions already built.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    # (The previous version created an unused ConfigParser here that was
    # immediately shadowed by the loop variable; removed.)
    configs = [read_config(filename) for filename in args.channels_file]
    for config in configs:
        for section in config.sections():

            if 'alias_of' in config[section]:
                # Aliases are handled in the second pass, once every
                # referenced section has an expression.
                assert 'git_repo' not in config[section]
                continue

            if 'channel_url' in config[section]:
                # Channel-backed: fetch the pinned tarball by URL+hash.
                tarball = fetch_with_nix_prefetch_url(
                    v, config[section]['tarball_url'], Digest16(
                        config[section]['tarball_sha256']))
            else:
                # Git-backed: build the tarball from the pinned revision.
                channel = Channel(**dict(config[section].items()))
                ensure_git_rev_available(v, channel)
                tarball = git_get_tarball(v, channel)

            if section in exprs:
                raise Exception('Duplicate channel "%s"' % section)
            # The %%s placeholder is filled with the channel name when
            # the nix-env command line is assembled below.
            exprs[section] = (
                'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
                (config[section]['release_name'], tarball))

    for config in configs:
        for section in config.sections():
            if 'alias_of' in config[section]:
                if section in exprs:
                    raise Exception('Duplicate channel "%s"' % section)
                exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
550
551
def main() -> None:
    """Parse command-line arguments and dispatch to pin or update."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pin)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=update)
    args = parser.parse_args()
    args.func(args)


# Guard the entry point so importing this module (e.g. for testing)
# does not immediately run the CLI; behavior as a script is unchanged.
if __name__ == '__main__':
    main()