]> git.scottworley.com Git - pinch/blame_incremental - pinch.py
This tool is named "pinch" now
[pinch] / pinch.py
... / ...
CommitLineData
1import argparse
2import configparser
3import filecmp
4import functools
5import getpass
6import hashlib
7import operator
8import os
9import os.path
10import shutil
11import subprocess
12import tempfile
13import types
14import urllib.parse
15import urllib.request
16import xml.dom.minidom
17
18from typing import (
19 Dict,
20 Iterable,
21 List,
22 NewType,
23 Tuple,
24)
25
# SHA-256 digest rendered in base16 (hex); this is what hashlib's hexdigest()
# and `nix to-base16` produce in this file.
Digest16 = NewType('Digest16', str)
# SHA-256 digest rendered in base32; this is what `nix to-base32` produces in
# this file.
Digest32 = NewType('Digest32', str)
28
29
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table found on a channel page.

    parse_channel() creates entries with url, digest and size;
    fetch_resources() later adds absolute_url and file.
    """
    # url resolved against the channel's forwarded URL.
    absolute_url: str
    # Base16 SHA-256 digest listed in the table for this file.
    digest: Digest16
    # Local path returned by fetch_with_nix_prefetch_url().
    file: str
    # Size listed in the table for this file.
    size: int
    # href of the table row's link, as written in the page.
    url: str
36
37
class Channel(types.SimpleNamespace):
    """Everything known about one channel while it is being pinned.

    Instances are built from the key/value pairs of a config section and
    then filled in attribute by attribute as the steps in this file run.
    """
    # Raw body downloaded from channel_url.
    channel_html: bytes
    # URL of the channel page; fetch() reads it from this object.
    channel_url: str
    # URL reported by the response after fetching channel_url.
    forwarded_url: str
    # Git ref that is fetched and that git_revision must descend from.
    git_ref: str
    # Git repository the ref is fetched from.
    git_repo: str
    # Revision being pinned (parsed from the channel page or read from the
    # fetched ref).
    git_revision: str
    # Revision recorded in the config file before this run, when present.
    old_git_revision: str
    # Release name parsed from the channel page's title.
    release_name: str
    # File table of the channel page, keyed by file name.
    table: Dict[str, ChannelTableEntry]
48
49
class VerificationError(Exception):
    """Raised by Verification.result() when a verification step fails."""
52
53
class Verification:
    """Prints each verification step followed by a right-aligned verdict."""

    def __init__(self) -> None:
        # Number of characters already printed on the current output line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print s plus a trailing space without ending the line."""
        print(s, end=' ', flush=True)
        self.line_length = self.line_length + len(s) + 1  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in an ANSI escape sequence selecting colour code c."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with a verdict for r.

        The verdict is padded so that it ends at the terminal's right edge.
        Raises VerificationError when r is False.
        """
        verdicts = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = verdicts[r]
        columns = shutil.get_terminal_size().columns
        used = self.line_length + len(message)
        # The modulo keeps the padding non-negative when the status text has
        # already wrapped past one terminal line.
        padding = ' ' * ((columns - used) % columns)
        print(padding + self._color(message, color))
        self.line_length = 0
        if r:
            return
        raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print step s and immediately report r as its verdict."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the step currently being printed."""
        self.result(True)
83
84
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the contents of directory trees a and b file by file.

    Every regular file and every symlinked directory found under either
    tree is considered, except paths starting with '.git/'.  Returns the
    (match, mismatch, errors) lists produced by filecmp.cmpfiles() with
    shallow=False.  Errors hit while walking a tree are re-raised.
    """

    def reraise(error: OSError) -> None:
        # os.walk() swallows errors unless the onerror hook raises them.
        raise error

    def relative(base: str, name: str) -> str:
        if base == '.':
            return name
        return os.path.join(base, name)

    def tracked_paths(root: str) -> set:
        found: List[str] = []
        for here, subdirs, filenames in os.walk(root, onerror=reraise):
            prefix = os.path.relpath(here, start=root)
            for filename in filenames:
                found.append(relative(prefix, filename))
            # os.walk() does not descend into symlinked directories, so list
            # the links themselves as entries to compare.
            for subdir in subdirs:
                if os.path.islink(relative(here, subdir)):
                    found.append(relative(prefix, subdir))
        return {p for p in found if not p.startswith('.git/')}

    candidates = tracked_paths(a) | tracked_paths(b)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
111
112
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page and record where the request ended up.

    Stores the response body in channel.channel_html and the response's
    final URL in channel.forwarded_url.  Fails verification unless the
    HTTP status is 200 and the final URL differs from channel.channel_url.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed even
    # if reading the body raises.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        channel.channel_html = response.read()
        channel.forwarded_url = response.geturl()
        status = response.status
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
120
121
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git revision and file table from the page.

    Parses channel.channel_html as XML and fills in channel.release_name,
    channel.git_revision and channel.table.  Fails verification when the
    title and h1 disagree on the release name, or when the text before the
    first <tt> element is not 'Git commit '.
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # The release name is taken to be the third whitespace-separated
        # word of the first element with this tag.
        element = document.getElementsByTagName(tag)[0]
        return element.firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    name_in_title = third_word('title')
    name_in_h1 = third_word('h1')
    v.status(name_in_title)
    v.result(name_in_title == name_in_h1)
    channel.release_name = name_in_title

    v.status('Extracting git commit:')
    commit_element = document.getElementsByTagName('tt')[0]
    channel.git_revision = commit_element.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_element.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first <tr> is skipped; the remaining rows are read as
    # (link, size, digest) cells.
    data_rows = document.getElementsByTagName('tr')[1:]
    for row in data_rows:
        cells = row.childNodes
        link = cells[0].firstChild
        entry_name = link.firstChild.nodeValue
        entry_url = link.getAttribute('href')
        entry_size = int(cells[1].firstChild.nodeValue)
        entry_digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[entry_name] = ChannelTableEntry(
            url=entry_url, digest=entry_digest, size=entry_size)
    v.ok()
153
154
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 digest of s as a hex string."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
157
158
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 digest of the named file as a hex string.

    The file is read in 4096-byte blocks so it never has to fit in memory.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if block == b'':
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
166
167
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 by running `nix to-base16`.

    Fails verification when the nix command exits non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
174
175
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 by running `nix to-base32`.

    Fails verification when the nix command exits non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
182
183
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download url with nix-prefetch-url and return the resulting path.

    The digest is checked three ways: it is handed to nix-prefetch-url,
    compared against the digest nix-prefetch-url prints, and compared
    against a digest computed here from the file's contents.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)

    # Expected output is exactly "<digest>\n<path>\n"; the trailing newline
    # yields the empty third field.
    output_fields = completed.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_fields
    assert empty == ''

    converted = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", converted == digest)

    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path
200
201
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    For both table entries, resolves the entry's url against
    channel.forwarded_url, stores it as absolute_url, fetches it and stores
    the local path as file.  Then fails verification unless the downloaded
    git-revision file matches channel.git_revision.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the file handle is closed promptly.
    with open(channel.table['git-revision'].file) as revision_file:
        revision_in_table = revision_file.read(999)
    v.result(revision_in_table == channel.git_revision)
213
214
def git_cachedir(git_repo: str) -> str:
    """Return the per-repository cache directory under ~/.cache/pinch/git.

    The directory name is the hex SHA-256 of the repository string.
    """
    # TODO: Consider using pyxdg to find this path.
    repo_digest = digest_string(git_repo.encode())
    return os.path.expanduser('~/.cache/pinch/git/%s' % repo_digest)
221
222
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check that the pinned revision sits where it should in history.

    Fails verification unless channel.git_revision is an ancestor of
    channel.git_ref.  When the channel carries an old_git_revision, also
    fails unless that previous revision is an ancestor of
    channel.git_revision, i.e. the pin never moves backwards or sideways.
    """
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              channel.git_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if hasattr(channel, 'old_git_revision'):
        # `merge-base --is-ancestor A B` succeeds when A is an ancestor of
        # B, so the message names the previous rev as the ancestor.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  channel.old_git_revision,
                                  channel.git_revision])
        v.result(process.returncode == 0)
247
248
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch channel.git_ref from channel.git_repo into the local cache.

    Creates the bare cache repository first if it does not exist.  When
    the channel already names a git_revision, fails verification unless
    the fetch made that object available; otherwise sets
    channel.git_revision from the fetched ref.  Finishes by verifying
    ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort.  So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        # Read through a context manager so the file handle is closed.
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
289
290
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_revision exists in the local git cache.

    When the cache already holds the revision, only ancestry is verified.
    Otherwise (no cache yet, or revision missing) the ref is fetched.
    Fails verification when `git cat-file -e` exits with anything other
    than 0 or 1.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        exit_code = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision]
        ).returncode
        answers = {0: 'yes', 1: 'no'}
        if exit_code in answers:
            v.status(answers[exit_code])
        v.result(exit_code in answers)
        if exit_code == 0:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
306
307
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check the unpacked channel tarball against the git checkout.

    Fails verification unless at least one file matches, no file differs,
    and the set of files that could not be compared is exactly the fixed
    list of files expected to exist on only one side.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    # Move the expected names out of errors so that only surprises remain.
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
339
340
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz into directory dest."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
348
349
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Export channel.git_revision from the git cache into directory dest.

    Pipes `git archive` into `tar x`.  Fails verification unless both
    processes exit with status 0.
    """
    v.status('Checking out corresponding git revision')
    archive_command = [
        'git', '-C', git_cachedir(channel.git_repo),
        'archive', channel.git_revision]
    unpack_command = ['tar', 'x', '-C', dest, '-f', '-']
    git = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    tar = subprocess.Popen(unpack_command, stdin=git.stdout)
    # Close our copy of the pipe so that only tar holds the read end.
    git.stdout.close()
    tar.wait()
    git.wait()
    both_succeeded = git.returncode == 0 and tar.returncode == 0
    v.result(both_succeeded)
364
365
def git_get_tarball(v: Verification, channel: Channel) -> str:
    """Build an xz tarball of channel.git_revision and add it to the store.

    Pipes `git archive` (with the release name as path prefix) through
    `xz` into a temporary file, then runs `nix-store --add` on that file.
    Returns the stripped stdout of `nix-store --add`.  Fails verification
    when any of the three commands exits non-zero.
    """
    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, channel.release_name + '.tar.xz')
        # xz writes compressed bytes to this descriptor, so open it in
        # binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                channel.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % channel.release_name,
                                    channel.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe (as git_checkout does) so git gets
            # SIGPIPE instead of blocking forever if xz exits early.
            git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        # Still inside the temporary directory's lifetime: the tarball must
        # exist while nix-store reads it.
        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], capture_output=True)
        v.result(process.returncode == 0)
        return process.stdout.decode().strip()
391
392
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files shipped inside the unpacked tarball.

    Fails verification unless .git-revision equals channel.git_revision
    and the content of .version-suffix is a suffix of
    channel.release_name.
    """

    def read_release_file(name: str) -> str:
        # Read at most 999 characters and close the file handle promptly.
        path = os.path.join(channel_contents, channel.release_name, name)
        with open(path) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_release_file('.git-revision') == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    version_suffix = read_release_file('.version-suffix')
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
415
416
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Verify the channel tarball against the corresponding git revision.

    Unpacks the tarball and exports the git revision into two temporary
    directories, checks the tarball's metadata files, and compares the
    two trees.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)

            git_checkout(v, channel, git_contents)

            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)

            v.status('Removing temporary directories')
    # Reported only once both directories have actually been removed.
    v.ok()
430
431
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every step needed to pin a channel that has a channel_url.

    Steps run in order and each one relies on attributes that the earlier
    ones stored on channel.
    """
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
438
439
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a release name for a channel that is pinned straight from git.

    Returns '<repo basename>-<git log output>', where the git log output
    for channel.git_revision is formatted as '%ct-%h' with an 11-character
    abbreviated hash.  Fails verification when git exits non-zero or
    prints nothing.
    """
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    # stdout is bytes here, so it must be compared with a bytes literal.
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
454
455
def pin(args: argparse.Namespace) -> None:
    """Implement the `pin` subcommand.

    Reads args.channels_file, pins every section to its current revision
    and writes the updated configuration back to the same file.  A
    section's existing git_revision becomes the channel's
    old_git_revision.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Read through a context manager so the file handle is closed before the
    # same path is reopened for writing below.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    for section in config.sections():

        channel = Channel(**dict(config[section].items()))
        if hasattr(channel, 'git_revision'):
            channel.old_git_revision = channel.git_revision
            del channel.git_revision

        if 'channel_url' in config[section]:
            pin_channel(v, channel)
            config[section]['release_name'] = channel.release_name
            config[section]['tarball_url'] = channel.table['nixexprs.tar.xz'].absolute_url
            config[section]['tarball_sha256'] = channel.table['nixexprs.tar.xz'].digest
        else:
            git_fetch(v, channel)
            config[section]['release_name'] = git_revision_name(v, channel)
        # Both branches leave the pinned revision in channel.git_revision.
        config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
479
480
def update(args: argparse.Namespace) -> None:
    """Implement the `update` subcommand.

    For every section of args.channels_file, obtains the pinned tarball
    (downloaded for sections with a channel_url, generated from git
    otherwise) and installs all of them into the current user's channels
    profile with a single nix-env invocation.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Read through a context manager so the file handle is closed promptly.
    with open(args.channels_file) as channels_file:
        config.read_file(channels_file, args.channels_file)
    exprs = []
    for section in config.sections():
        if 'channel_url' in config[section]:
            tarball = fetch_with_nix_prefetch_url(
                v, config[section]['tarball_url'], Digest16(
                    config[section]['tarball_sha256']))
        else:
            channel = Channel(**dict(config[section].items()))
            ensure_git_rev_available(v, channel)
            tarball = git_get_tarball(v, channel)
        exprs.append(
            'f: f { name = "%s"; channelName = "%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], section, tarball))
    v.status('Installing channels with nix-env')
    process = subprocess.run(
        [
            'nix-env',
            '--profile',
            '/nix/var/nix/profiles/per-user/%s/channels' %
            getpass.getuser(),
            '--show-trace',
            '--file',
            '<nix/unpack-channel.nix>',
            '--install',
            '--from-expression'] +
        exprs)
    v.result(process.returncode == 0)
512
513
def main() -> None:
    """Parse the command line and dispatch to the chosen subcommand.

    Both subcommands, `pin` and `update`, take a single channels_file
    argument.
    """
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    for mode, handler in (('pin', pin), ('update', update)):
        subparser = subparsers.add_parser(mode)
        subparser.add_argument('channels_file', type=str)
        subparser.set_defaults(func=handler)
    args = parser.parse_args()
    args.func(args)
525
526
# Run the command-line interface only when executed as a script, so that
# importing this module does not parse sys.argv.
if __name__ == '__main__':
    main()