]> git.scottworley.com Git - pinch/blob - pinch.py
Add --dry_run flag
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import tempfile
14 import types
15 import urllib.parse
16 import urllib.request
17 import xml.dom.minidom
18
19 from typing import (
20 Dict,
21 Iterable,
22 List,
23 NewType,
24 Tuple,
25 )
26
# Hex (base16) SHA-256 digest string, e.g. as produced by hashlib's hexdigest().
Digest16 = NewType('Digest16', str)
# Nix base32-encoded SHA-256 digest string, e.g. as printed by nix-prefetch-url.
Digest32 = NewType('Digest32', str)
29
30
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's release page.

    parse_channel creates entries with url, digest and size; fetch_resources
    later fills in absolute_url and file for the resources it downloads.
    """
    # The row's link resolved (urljoin) against the channel's forwarded URL.
    absolute_url: str
    # Hex SHA-256 digest as listed in the table.
    digest: Digest16
    # Path printed by `nix-prefetch-url --print-path` for the downloaded file.
    file: str
    # The table's size column, parsed as an int.
    size: int
    # The link's href exactly as written in the page (possibly relative).
    url: str
37
38
class Channel(types.SimpleNamespace):
    """State for one channel being pinned or updated.

    Instances start out as the key/value pairs of a channels-file section
    and are filled in as the pin/update steps run; not every field is set
    for every channel.
    """
    # Raw bytes of the channel page, as read by fetch().
    channel_html: bytes
    # URL of the channel page, taken from the channels file.
    channel_url: str
    # The URL the channel page request ended up at after redirects.
    forwarded_url: str
    # Ref to fetch; stored in the cache repo as refs/heads/<git_ref>.
    git_ref: str
    # Git repository URL; also determines the local cache directory.
    git_repo: str
    # Commit being pinned (or already pinned, during update).
    git_revision: str
    # Revision recorded by a previous pin; the new one must descend from it.
    old_git_revision: str
    # Release name, from the channel page or derived from git metadata.
    release_name: str
    # The channel page's file table, keyed by file name.
    table: Dict[str, ChannelTableEntry]
49
50
class VerificationError(Exception):
    """Raised by Verification.result() after it reports a failed check."""
53
54
class Verification:
    """Print progress messages followed by a right-aligned OK/FAIL verdict.

    Callers print one or more status() fragments on a line and then end the
    line with result()/check()/ok().  A failing result prints FAIL and raises
    VerificationError.
    """

    def __init__(self) -> None:
        # Characters printed by status() on the current line so far; used to
        # right-align the verdict.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* (plus a separating space) without ending the line."""
        print(s, end=' ', flush=True)
        # len() counts code points, which may differ from the displayed width
        # for wide or combining Unicode characters.
        self.line_length += 1 + len(s)

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in an ANSI SGR escape using code *c* (e.g. 91 red, 92 green)."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL; raise VerificationError if not *r*."""
        # 'OK ' is padded to the same width as 'FAIL'.
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        # get_terminal_size() can report 0 columns (e.g. a pty sized 0x0 on
        # Python < 3.11); clamp so the modulo below cannot divide by zero.
        cols = max(shutil.get_terminal_size().columns, 1)
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print status *s* and immediately report result *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current line."""
        self.result(True)
84
85
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at *a* and *b* by content.

    Every regular file, plus every symlink that os.walk lists as a directory,
    found under either root is considered, except paths under a top-level
    ``.git/`` directory.  Returns the ``(match, mismatch, errors)`` triple of
    ``filecmp.cmpfiles`` with paths relative to the roots; a file present on
    only one side ends up in ``errors``.
    """

    def raise_walk_error(error: OSError) -> None:
        # os.walk silently ignores errors unless given an onerror callback.
        raise error

    def relative_join(prefix: str, name: str) -> str:
        # relpath() reports the root itself as '.'; avoid producing './name'.
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def tree_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for here, subdirs, filenames in os.walk(root, onerror=raise_walk_error):
            relative = os.path.relpath(here, start=root)
            for filename in filenames:
                found.append(relative_join(relative, filename))
            # Symlinks to directories show up in subdirs and are not descended
            # into; record them so the links themselves get compared.
            found.extend(
                relative_join(relative, subdir)
                for subdir in subdirs
                if os.path.islink(relative_join(here, subdir)))
        return found

    def without_dot_git(paths: Iterable[str]) -> Iterable[str]:
        return (p for p in paths if not p.startswith('.git/'))

    wanted = (set(without_dot_git(tree_files(a)))
              | set(without_dot_git(tree_files(b))))
    return filecmp.cmpfiles(a, b, wanted, shallow=False)
112
113
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page, recording its HTML and post-redirect URL.

    Sets channel.channel_html and channel.forwarded_url.  Verification fails
    if the final response status is not 200, or if the request was not
    redirected away from channel.channel_url.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # promptly, even if reading the body raises.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
121
122
def parse_channel(v: Verification, channel: Channel) -> None:
    """Parse channel.channel_html and record what the page advertises.

    Sets channel.release_name (the third whitespace-separated word of the
    page's <title>, which must equal the same word of its <h1>),
    channel.git_revision (text of the first <tt> element, which must be
    preceded by the text "Git commit "), and channel.table (one
    ChannelTableEntry per <tr> after the first).  A failed consistency check
    raises VerificationError via v.result().
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_element(tag: str) -> xml.dom.minidom.Element:
        return document.getElementsByTagName(tag)[0]

    def third_word(tag: str) -> str:
        return first_element(tag).firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    release_from_title = third_word('title')
    release_from_h1 = third_word('h1')
    v.status(release_from_title)
    v.result(release_from_title == release_from_h1)
    channel.release_name = release_from_title

    v.status('Extracting git commit:')
    tt_node = first_element('tt')
    channel.git_revision = tt_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = tt_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first <tr> is skipped (presumably the header row).
    for row in document.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
154
155
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 digest of the bytes *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
158
159
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of the file at *filename*.

    The file is read in 4096-byte chunks so large files are never held in
    memory at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
167
168
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base32 SHA-256 digest to hex with ``nix to-base16``.

    Verification fails (raising VerificationError) if the command exits
    non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
175
176
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base32 with ``nix to-base32``.

    Verification fails (raising VerificationError) if the command exits
    non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
183
184
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with nix-prefetch-url and verify it against *digest*.

    The digest is checked twice: once against the hash nix-prefetch-url
    reports (converted from base32), and once against an independent SHA-256
    of the downloaded file.  Returns the path nix-prefetch-url printed.
    """
    v.status('Fetching %s' % url)
    fetched = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], capture_output=True)
    v.result(fetched.returncode == 0)
    # Expected output: "<hash>\n<path>\n", i.e. exactly two newline-terminated lines.
    reported_digest, store_path, after_last_newline = (
        fetched.stdout.decode().split('\n'))
    assert after_last_newline == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(reported_digest)) == digest)
    v.status("Verifying file digest")
    v.result(digest_file(store_path) == digest)
    return store_path
201
202
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each table entry gets absolute_url (its href resolved against the
    forwarded page URL) and file (the verified downloaded path).  Then
    checks that the downloaded git-revision file matches the commit shown
    on the channel page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaking the handle until GC.
    with open(channel.table['git-revision'].file) as git_revision_file:
        table_revision = git_revision_file.read(999)
    v.result(table_revision == channel.git_revision)
214
215
def git_cachedir(git_repo: str) -> str:
    """Return the path of the local git cache directory for *git_repo*.

    Each repository gets its own directory under ~/.cache/pinch/git, named
    by the hex SHA-256 digest of the repository URL string.
    """
    # TODO: Consider using pyxdg to find this path.
    repo_key = digest_string(git_repo.encode())
    return os.path.expanduser('~/.cache/pinch/git/%s' % repo_key)
222
223
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check channel.git_revision's place in history using the local cache.

    Verifies that the revision is an ancestor of channel.git_ref.  When a
    previously pinned revision is recorded in channel.old_git_revision, also
    verifies that the old revision is an ancestor of the new one, so a pin
    only ever moves forward.  A failed check raises VerificationError via
    v.result().
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(older: str, newer: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 iff A is an ancestor of B.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  older,
                                  newer])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The previous rev must be the ancestor here (argument order above).
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
248
249
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local bare cache.

    Creates the cache repository if needed.  If channel.git_revision is
    already set, verifies that the fetch retrieved that commit; otherwise
    sets it to the commit the fetched ref now points to.  Finally runs
    verify_git_ancestry().
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # Ask git to resolve the ref instead of reading refs/heads/<ref>
        # directly: once refs are packed (git gc / pack-refs) the loose file
        # may not exist, and an up-to-date fetch does not recreate it.
        v.status('Resolving fetched ref "%s"' % channel.git_ref)
        process = subprocess.run(
            ['git', '-C', cachedir, 'rev-parse', '--verify',
             'refs/heads/%s' % channel.git_ref],
            capture_output=True)
        v.result(process.returncode == 0)
        channel.git_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel)
290
291
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision exists in the local git cache.

    If the cache already exists and has the commit, only the ancestry checks
    are re-run.  If there is no cache or the commit is missing, git_fetch()
    is used.  A `git cat-file -e` exit status other than 0 or 1 fails
    verification.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        code = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision]
        ).returncode
        answer = {0: 'yes', 1: 'no'}.get(code)
        if answer is not None:
            v.status(answer)
        v.result(answer is not None)
        if code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
307
308
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify the unpacked channel tarball matches the git checkout.

    The tarball tree under ``<channel_contents>/<release_name>`` is compared
    with *git_contents*.  At least one file must match and none may differ.
    The only incomparable files allowed are a fixed list of expected names,
    and all of those must be present.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Names expected to be incomparable (presumably present on only one side,
    # e.g. the channel metadata files checked by check_channel_metadata).
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    unexpected_errors = [name for name in errors if name not in benign_errors]
    v.check(
        '%d unexpected incomparable files' % len(unexpected_errors),
        len(unexpected_errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
340
341
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the channel's downloaded nixexprs.tar.xz into *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
349
350
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Extract the tree of channel.git_revision into *dest*.

    Pipes `git archive` (run against the local cache repository) into
    `tar x`.  Verification fails unless both processes exit 0.
    """
    v.status('Checking out corresponding git revision')
    archiver = subprocess.Popen(
        ['git', '-C', git_cachedir(channel.git_repo),
         'archive', channel.git_revision],
        stdout=subprocess.PIPE)
    extractor = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archiver.stdout)
    # Drop our copy of the pipe so git gets SIGPIPE if tar exits early.
    archiver.stdout.close()
    extractor.wait()
    archiver.wait()
    both_succeeded = archiver.returncode == 0 and extractor.returncode == 0
    v.result(both_succeeded)
365
366
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Build an xz tarball of channel.git_revision and add it to the Nix store.

    The archive's top-level directory is named channel.release_name.  The
    tarball is built in a temporary directory and removed afterwards.
    Returns the store path printed by ``nix-store --add``.
    """
    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # Binary mode: xz writes compressed bytes straight into this file.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe's read end (as git_checkout does) so
            # git gets SIGPIPE instead of blocking forever if xz exits early.
            git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], capture_output=True)
        v.result(process.returncode == 0)
        return process.stdout.decode().strip()
392
393
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    ``.git-revision`` must equal channel.git_revision, and the contents of
    ``.version-suffix`` must be a suffix of channel.release_name.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    def read_metadata(name: str) -> str:
        # Read at most 999 characters and close the file promptly rather
        # than leaking the handle.
        with open(os.path.join(release_dir, name)) as metadata_file:
            return metadata_file.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
416
417
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack the channel tarball and the matching git tree, then compare them.

    Both trees live in temporary directories.  The final OK is reported only
    after those directories have been removed.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            v.status('Removing temporary directories')
    v.ok()
431
432
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every verification step needed to pin a channel_url channel, in order."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
439
440
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Return a release name for channel.git_revision.

    The name is "<basename of git_repo>-<committer timestamp>-<abbreviated
    hash>", using `git log --format=%ct-%h` with --abbrev=11 against the
    local cache.  Verification fails if git fails or prints nothing.
    """
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    # stdout is bytes, so it must be compared with b'' (never equal to '').
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
455
456
def pin(args: argparse.Namespace) -> None:
    """Re-pin channels in args.channels_file and write the results back.

    If args.channels is non-empty, only those sections are pinned.  Channels
    with a channel_url are verified against their release page and record
    release_name, tarball_url and tarball_sha256.  Other channels are fetched
    from git and get a git-derived release_name.  All record git_revision.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Close the channels file once parsed instead of leaking the handle.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        channel = Channel(**dict(config[section].items()))
        # A revision already in the file is the previous pin: keep it for the
        # ancestry check and let this run determine the new revision.
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in config[section]:
            pin_channel(v, channel)
            config[section]['release_name'] = channel.release_name
            config[section]['tarball_url'] = channel.table['nixexprs.tar.xz'].absolute_url
            config[section]['tarball_sha256'] = channel.table['nixexprs.tar.xz'].digest
        else:
            git_fetch(v, channel)
            config[section]['release_name'] = git_revision_name(v, channel)
        config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
482
483
def update(args: argparse.Namespace) -> None:
    """Install all channels in args.channels_file into the user's channels profile.

    Each channel's tarball is fetched (channel_url channels) or built from
    git (other channels), then a single nix-env command installs them all.
    With args.dry_run, the nix-env command is printed instead of run.  The
    tarballs are still fetched/built in that case, because their store paths
    are part of the command.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Close the channels file once parsed instead of leaking the handle.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    exprs = []
    for section in config.sections():
        if 'channel_url' in config[section]:
            tarball = fetch_with_nix_prefetch_url(
                v, config[section]['tarball_url'], Digest16(
                    config[section]['tarball_sha256']))
        else:
            channel = Channel(**dict(config[section].items()))
            ensure_git_rev_available(v, channel)
            tarball = git_get_tarball(v, channel)
        exprs.append(
            'f: f { name = "%s"; channelName = "%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], section, tarball))
    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + exprs
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
517
518
def main() -> None:
    """Parse the command line and dispatch to the `pin` or `update` subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pin)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str)
    parser_update.set_defaults(func=update)
    args = parser.parse_args()
    args.func(args)


# Only run when executed as a script, so importing the module (for tests or
# reuse) does not parse sys.argv and start pinning.
if __name__ == '__main__':
    main()