]> git.scottworley.com Git - pinch/blame_incremental - pinch.py
Split ensure_git_rev_available out from git_fetch
[pinch] / pinch.py
... / ...
CommitLineData
1import argparse
2import configparser
3import filecmp
4import functools
5import hashlib
6import operator
7import os
8import os.path
9import shutil
10import subprocess
11import tempfile
12import types
13import urllib.parse
14import urllib.request
15import xml.dom.minidom
16
17from typing import (
18 Dict,
19 Iterable,
20 List,
21 NewType,
22 Tuple,
23)
24
# Base16 (hexadecimal) text form of a SHA-256 digest.
Digest16 = NewType('Digest16', str)
# Base32 text form of a SHA-256 digest, as exchanged with the nix
# command-line tools.
Digest32 = NewType('Digest32', str)
27
28
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table found on a channel page.

    ``url``, ``digest`` and ``size`` are supplied when the table is parsed;
    ``absolute_url`` and ``file`` are attached afterwards, when the
    resource is fetched.
    """
    # ``url`` resolved against the address the channel request ended at.
    absolute_url: str
    # Digest text taken from the row's third cell.
    digest: Digest16
    # Local path at which the fetched resource is stored.
    file: str
    # Integer from the row's second cell (presumably the size in bytes).
    size: int
    # Link target exactly as written in the row's ``href`` attribute.
    url: str
35
36
class Channel(types.SimpleNamespace):
    """Everything known about one channel while it is being pinned.

    Instances start out holding the key/value pairs of a configuration
    section and gain further attributes as the pinning steps run.
    """
    # Raw body of the channel page as downloaded.
    channel_html: bytes
    # Address of the channel page to download.
    channel_url: str
    # Address the channel request finally ended at.
    forwarded_url: str
    # Name of the git ref that is fetched.
    git_ref: str
    # Location of the git repository that is fetched from.
    git_repo: str
    # Commit being pinned.
    git_revision: str
    # Commit recorded by an earlier pin, when the configuration had one.
    old_git_revision: str
    # Release name read from the channel page.
    release_name: str
    # Rows of the channel page's file table, keyed by file name.
    table: Dict[str, ChannelTableEntry]
47
48
class VerificationError(Exception):
    """Raised when a verification step reports failure."""
51
52
class Verification:
    """Console reporter for a sequence of verification steps.

    Step descriptions are printed left to right on one line; the verdict
    is then padded so that it ends at the terminal's right edge.
    """

    def __init__(self) -> None:
        # Number of characters already written on the current line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* and a trailing space without ending the line."""
        print(s, end=' ', flush=True)
        # len() counts code points, which may differ from the number of
        # terminal cells for non-ASCII text.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the terminal escape sequence for colour code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print the verdict for the current line.

        Raises:
            VerificationError: if *r* is false.
        """
        verdicts = {False: ('FAIL', 91), True: ('OK ', 92)}
        message, color = verdicts[r]
        width = shutil.get_terminal_size().columns
        used = self.line_length + len(message)
        # The modulo keeps the padding non-negative when the text has
        # already wrapped past the terminal width.
        padding = (width - used) % width
        print(' ' * padding + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print description *s* immediately followed by verdict *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current line."""
        self.result(True)
82
83
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file contents of directory trees *a* and *b*.

    Every file found under either tree (plus every symlink that
    ``os.walk`` lists among the directories) is compared by content,
    except for paths beginning with ``.git/``.

    Returns:
        The ``(match, mismatch, errors)`` lists of relative paths
        produced by ``filecmp.cmpfiles``.

    Raises:
        OSError: if walking either tree fails.
    """

    def reraise(error: OSError) -> None:
        # os.walk ignores errors by default; make them fatal instead.
        raise error

    def rel_join(parent: str, child: str) -> str:
        # Avoid a leading "./" on entries of the top-level directory.
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def recursive_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for current, subdirs, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(current, start=root)
            for filename in filenames:
                found.append(rel_join(relative, filename))
            for subdir in subdirs:
                # Symlinks to directories are listed as directories by
                # os.walk; record them so they get compared too.
                if os.path.islink(rel_join(current, subdir)):
                    found.append(rel_join(relative, subdir))
        return found

    candidates = set()
    for root in (a, b):
        for entry in recursive_files(root):
            if not entry.startswith('.git/'):
                candidates.add(entry)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
110
111
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page and record where the request ended up.

    Stores the response body in ``channel.channel_html`` and the final
    address in ``channel.forwarded_url``.

    Raises:
        VerificationError: if the response status is not 200, or if the
            final address equals ``channel.channel_url`` (i.e. the
            request was not forwarded).
    """
    v.status('Fetching channel')
    # Use the response as a context manager so it is closed once read,
    # instead of leaking it.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)
119
120
def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git commit and file table from the page.

    Parses ``channel.channel_html`` as XML and fills in
    ``channel.release_name``, ``channel.git_revision`` and
    ``channel.table``.

    Raises:
        VerificationError: if the title and first heading disagree on
            the release name, or the commit is not preceded by the text
            ``'Git commit '``.
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        # The release name is the third whitespace-separated word of the
        # first element with this tag.
        element = document.getElementsByTagName(tag)[0]
        return element.firstChild.nodeValue.split()[2]

    v.status('Extracting release name:')
    title_name = third_word('title')
    h1_name = third_word('h1')
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_node = document.getElementsByTagName('tt')[0]
    channel.git_revision = commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    rows = document.getElementsByTagName('tr')
    # The first row is skipped (presumably the table header).
    for row in rows[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        entry = ChannelTableEntry(url=url, digest=digest, size=size)
        channel.table[name] = entry
    v.ok()
152
153
def digest_string(s: bytes) -> Digest16:
    """Return the SHA-256 of *s* as a hexadecimal digest."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
156
157
def digest_file(filename: str) -> Digest16:
    """Return the SHA-256 of the file's contents as a hexadecimal digest.

    The file is read in 4096-byte blocks rather than loaded whole.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if block == b'':
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
165
166
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to base16 by running ``nix``.

    Raises:
        VerificationError: if the ``nix`` command exits non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
173
174
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 SHA-256 digest to base32 by running ``nix``.

    Raises:
        VerificationError: if the ``nix`` command exits non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
181
182
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with ``nix-prefetch-url`` and verify its digest.

    Both the digest printed by ``nix-prefetch-url`` and a digest computed
    here from the resulting file are checked against *digest*.

    Returns:
        The path printed by ``nix-prefetch-url --print-path``.

    Raises:
        VerificationError: if the command fails or either digest differs.
    """
    v.status('Fetching %s' % url)
    command = ['nix-prefetch-url', '--print-path', url, digest]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    # Expected output: a digest line, a path line, then the empty string
    # that follows the final newline.
    output_lines = completed.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''
    reported_digest = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported_digest == digest)
    v.status("Verifying file digest")
    computed_digest = digest_file(path)
    v.result(computed_digest == digest)
    return path
199
200
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download the ``git-revision`` and ``nixexprs.tar.xz`` resources.

    For each of the two table entries, resolves its URL against
    ``channel.forwarded_url`` into ``absolute_url``, fetches it, and
    stores the resulting local path in ``file``. Then checks that the
    downloaded ``git-revision`` file holds ``channel.git_revision``.

    Raises:
        VerificationError: if a fetch fails verification or the revision
            in the downloaded file differs from the one on the page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Read through a context manager so the file handle is closed rather
    # than leaked.
    with open(channel.table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == channel.git_revision)
212
213
def git_cachedir(git_repo: str) -> str:
    """Return the cache directory path used for *git_repo*.

    The directory name is the SHA-256 hex digest of the repository
    string, under ``~/.cache/nix-pin-channel/git``.
    """
    # TODO: Consider using pyxdg to find this path.
    repo_key = digest_string(git_repo.encode())
    return os.path.expanduser('~/.cache/nix-pin-channel/git/%s' % repo_key)
220
221
def verify_git_ancestry(v: Verification, channel: Channel) -> None:
    """Check that the pinned revision sits where it should in history.

    Verifies, in the cached repository for ``channel.git_repo``, that
    ``channel.git_revision`` is an ancestor of ``channel.git_ref`` and,
    when the channel has an ``old_git_revision``, that the old revision
    is an ancestor of the new one.

    Raises:
        VerificationError: if either ancestry check fails.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 exactly when A is
        # an ancestor of B.
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(channel.git_revision, channel.git_ref))

    if hasattr(channel, 'old_git_revision'):
        # The previous rev is passed first, so it is the one required to
        # be the ancestor; the message states the check in that direction.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            channel.old_git_revision)
        v.result(is_ancestor(channel.old_git_revision, channel.git_revision))
246
247
def git_fetch(v: Verification, channel: Channel) -> None:
    """Fetch ``channel.git_ref`` into the local cache repository.

    Creates the bare cache repository if it does not exist, fetches the
    ref from ``channel.git_repo``, and then either confirms that the
    fetch brought in ``channel.git_revision`` or, when no revision is
    set yet, sets it from the fetched ref. Finishes by verifying
    ancestry.

    Raises:
        VerificationError: if any git command fails or a check does not
            hold.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if hasattr(channel, 'git_revision'):
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        v.result(process.returncode == 0)
    else:
        # No revision requested: take whatever the fetched ref points at.
        # The ref file is read via a context manager so the handle is
        # closed instead of leaked.
        ref_path = os.path.join(cachedir, 'refs', 'heads', channel.git_ref)
        with open(ref_path) as ref_file:
            channel.git_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel)
288
289
def ensure_git_rev_available(v: Verification, channel: Channel) -> None:
    """Make sure ``channel.git_revision`` exists in the cache repository.

    When the cache directory exists and already contains the revision,
    only the ancestry checks run; otherwise the ref is fetched.

    Raises:
        VerificationError: if the presence probe exits with a code other
            than 0 or 1, or a downstream check fails.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        probe = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', channel.git_revision])
        code = probe.returncode
        have_rev = code == 0
        if have_rev:
            v.status('yes')
        elif code == 1:
            v.status('no')
        # Any exit code besides 0 and 1 is treated as a failure.
        v.result(code in (0, 1))
        if have_rev:
            verify_git_ancestry(v, channel)
            return
    git_fetch(v, channel)
305
306
def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the unpacked tarball matches the git checkout.

    Compares ``<channel_contents>/<release_name>`` with *git_contents*.
    Requires at least one matching file, no differing files, and that
    the incomparable files are exactly the five expected names.

    Raises:
        VerificationError: if any of those requirements is not met.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, channel.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()

    match_count = len(match)
    v.check('%d files match' % match_count, match_count > 0)
    mismatch_count = len(mismatch)
    v.check('%d files differ' % mismatch_count, mismatch_count == 0)

    # Names that are expected to be incomparable between the two trees.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)

    unexpected_count = len(errors)
    v.check(
        '%d unexpected incomparable files' % unexpected_count,
        unexpected_count == 0)
    benign_count = len(benign_errors)
    expected_count = len(expected_errors)
    v.check(
        '(%d of %d expected incomparable files)' %
        (benign_count, expected_count),
        benign_count == expected_count)
338
339
def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the fetched ``nixexprs.tar.xz`` into directory *dest*."""
    tarball = channel.table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
347
348
def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Export ``channel.git_revision`` from the cache repo into *dest*.

    Pipes ``git archive`` into ``tar x``.

    Raises:
        VerificationError: if either process exits non-zero.
    """
    v.status('Checking out corresponding git revision')
    archive_command = ['git',
                       '-C',
                       git_cachedir(channel.git_repo),
                       'archive',
                       channel.git_revision]
    unpack_command = ['tar', 'x', '-C', dest, '-f', '-']
    git = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    tar = subprocess.Popen(unpack_command, stdin=git.stdout)
    # Drop this process's copy of the pipe's read end; tar holds its own.
    git.stdout.close()
    tar.wait()
    git.wait()
    exit_codes = (git.returncode, tar.returncode)
    v.result(exit_codes == (0, 0))
363
364
def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Check the metadata files inside the unpacked channel tarball.

    Verifies that ``.git-revision`` equals ``channel.git_revision`` and
    that the content of ``.version-suffix`` is a suffix of
    ``channel.release_name``. Both files are looked up under
    ``<channel_contents>/<release_name>``.

    Raises:
        VerificationError: if either check fails.
    """
    release_dir = os.path.join(channel_contents, channel.release_name)

    v.status('Verifying git commit in channel tarball')
    # Files are read through context managers so their handles are
    # closed instead of leaked.
    with open(os.path.join(release_dir, '.git-revision')) as revision_file:
        tarball_revision = revision_file.read(999)
    v.result(tarball_revision == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))
387
388
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack tarball and git revision side by side and compare them.

    Both are extracted into temporary directories that are removed
    before the function returns.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, channel, channel_contents)
            check_channel_metadata(v, channel, channel_contents)
            git_checkout(v, channel, git_contents)
            compare_tarball_and_git(
                v, channel, channel_contents, git_contents)
            # Announced just before leaving the blocks, which is when
            # the directories actually get deleted.
            v.status('Removing temporary directories')
    v.ok()
402
403
def pin_channel(v: Verification, channel: Channel) -> None:
    """Run every pinning step for a channel that has a channel URL."""
    steps = (
        fetch,
        parse_channel,
        fetch_resources,
        ensure_git_rev_available,
        check_channel_contents,
    )
    for step in steps:
        step(v, channel)
410
411
def git_revision_name(v: Verification, channel: Channel) -> str:
    """Build a name for the pinned revision of a git-only channel.

    Returns:
        ``<repo basename>-<commit timestamp>-<abbreviated hash>``, where
        the last two parts come from ``git log --format=%ct-%h``.

    Raises:
        VerificationError: if ``git log`` fails or prints nothing.
    """
    v.status('Getting commit date')
    # 'log' must be spelled out: git does not accept abbreviated
    # subcommand names, so 'lo' only works where a user alias exists.
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              channel.git_revision],
                             capture_output=True)
    # stdout is bytes here, so it has to be compared against b''; a
    # comparison with the str '' is always unequal and checks nothing.
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
426
427
def make_channel(conf: configparser.SectionProxy) -> Channel:
    """Build a Channel from one configuration section.

    Every key of the section becomes an attribute, except that a
    ``git_revision`` key is stored as ``old_git_revision`` instead, so
    that the new revision can be filled in during pinning.
    """
    fields = dict(conf.items())
    if 'git_revision' in fields:
        fields['old_git_revision'] = fields.pop('git_revision')
    return Channel(**fields)
434
435
def pin(args: argparse.Namespace) -> None:
    """Pin every channel listed in the channels file and save the result.

    Sections that have a ``channel_url`` go through the full channel
    pinning procedure and get ``name``, ``tarball_url`` and
    ``tarball_sha256`` recorded. Other sections are fetched from git
    only and get ``name`` recorded. ``git_revision`` is recorded for
    both. The file at ``args.channels_file`` is then rewritten.

    Raises:
        VerificationError: if any verification step fails; the file is
            not rewritten in that case.
    """
    v = Verification()
    config = configparser.ConfigParser()
    # Open through a context manager so the handle is closed after
    # parsing instead of being leaked.
    with open(args.channels_file) as channels:
        config.read_file(channels, args.channels_file)
    for section in config.sections():
        channel = make_channel(config[section])
        if 'channel_url' in config[section]:
            pin_channel(v, channel)
            config[section]['name'] = channel.release_name
            config[section]['tarball_url'] = channel.table['nixexprs.tar.xz'].absolute_url
            config[section]['tarball_sha256'] = channel.table['nixexprs.tar.xz'].digest
        else:
            git_fetch(v, channel)
            config[section]['name'] = git_revision_name(v, channel)
        config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
454
455
def main() -> None:
    """Parse the command line and run the selected sub-command."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    # "pin <channels_file>" is handled by pin().
    pin_parser = subparsers.add_parser('pin')
    pin_parser.set_defaults(func=pin)
    pin_parser.add_argument('channels_file', type=str)

    args = parser.parse_args()
    args.func(args)
464
465
# Only run the command-line interface when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()