]> git.scottworley.com Git - pinch/blob - pinch.py
e00bfd8cf0c43b1b2dda67d770668d5ebaf44041
[pinch] / pinch.py
1 import filecmp
2 import functools
3 import hashlib
4 import operator
5 import os
6 import os.path
7 import shutil
8 import subprocess
9 import tempfile
10 import types
11 import urllib.parse
12 import urllib.request
13 import xml.dom.minidom
14
15 from typing import (
16 Dict,
17 Iterable,
18 List,
19 NewType,
20 Tuple,
21 )
22
# Hex-encoded ("base16") SHA-256 digest: the form produced by hashlib's
# hexdigest() and listed in the channel page's file table.
Digest16 = NewType("Digest16", str)
# Nix-style base32 SHA-256 digest, as printed by nix-prefetch-url.
Digest32 = NewType("Digest32", str)
25
26
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's release page.

    parse_channel() fills in url, digest and size from the page; `file` is
    assigned afterwards by fetch_resources() once the resource is fetched.
    """

    # SHA-256 of the resource, hex-encoded, as listed on the page.
    digest: Digest16
    # Local path returned by nix-prefetch-url for the downloaded resource.
    file: str
    # Size column from the page (presumably bytes -- not checked here).
    size: int
    # Link target from the page; may be relative to the channel's URL.
    url: str
32
33
class Channel(types.SimpleNamespace):
    """State accumulated while verifying one channel release.

    main() supplies url, git_repo and git_ref; the remaining attributes are
    filled in by the verification steps as they run.
    """

    # Raw body of the channel page (set by fetch).
    channel_html: bytes
    # URL the channel request was redirected to (set by fetch).
    forwarded_url: str
    # Local bare git repository used for the git checks (set by git_fetch).
    git_cachedir: str
    # Commit hash shown on the channel page (set by parse_channel).
    git_commit: str
    # Ref fetched from git_repo, e.g. 'nixos-20.03'.
    git_ref: str
    # Upstream git repository URL.
    git_repo: str
    # NOTE(review): git_revision is never assigned in this file; the commit
    # hash is stored in git_commit. Kept for compatibility.
    git_revision: str
    # Release name extracted from the page title (set by parse_channel).
    release_name: str
    # File name -> table row from the channel page (set by parse_channel).
    table: Dict[str, ChannelTableEntry]
    # Channel URL the verification starts from.
    url: str
44
45
class VerificationError(Exception):
    """Raised by Verification.result() when a check does not pass."""
48
49
class Verification:
    """Print check descriptions with a right-aligned OK/FAIL mark per line.

    Callers print one or more status() fragments on a line, then finish the
    line with result(), check() or ok().  A failed result raises
    VerificationError, aborting the run.
    """

    def __init__(self) -> None:
        # Columns printed so far on the current status line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print s followed by a space, without ending the line."""
        print(s, end=' ', flush=True)
        # NOTE: len() counts code points, not terminal cells, so wide or
        # combining characters can throw off the alignment.
        self.line_length += 1 + len(s)

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in the ANSI SGR escape sequence for color code c."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK (r true) or FAIL (r false).

        Raises:
            VerificationError: if r is false.
        """
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        # Some terminals (and shutil.get_terminal_size on older Pythons)
        # report a width of 0; fall back to 80 columns rather than dividing
        # by zero below.
        cols = shutil.get_terminal_size().columns or 80
        # Right-align the mark; the modulo handles status text that has
        # already wrapped past the terminal width.
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print description s, then finish the line with result(r)."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Finish the current line with a passing mark."""
        self.result(True)
79
80
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at a and b by content, ignoring .git/.

    Every regular file under either root is considered, plus every symlink
    that os.walk lists as a directory (such links are not descended into).
    Paths starting with '.git/' are skipped.

    Returns:
        The (match, mismatch, errors) lists from filecmp.cmpfiles with
        shallow=False, holding paths relative to the roots.  A file present
        on only one side ends up in errors.

    Raises:
        OSError: if either tree cannot be walked.
    """

    def raise_walk_error(err: OSError) -> None:
        raise err

    def relative_join(prefix: str, name: str) -> str:
        # os.path.relpath returns '.' for the root itself; avoid './name'.
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def tree_files(root: str) -> Iterable[str]:
        found: List[str] = []
        walker = os.walk(root, onerror=raise_walk_error)
        for dirpath, subdirs, filenames in walker:
            relative = os.path.relpath(dirpath, start=root)
            for filename in filenames:
                found.append(relative_join(relative, filename))
            # Symlinked directories are listed but not followed; treat them
            # as files so the links themselves get compared.
            found.extend(
                relative_join(relative, sub)
                for sub in subdirs
                if os.path.islink(relative_join(dirpath, sub)))
        return found

    wanted = {
        rel_path
        for root in (a, b)
        for rel_path in tree_files(root)
        if not rel_path.startswith('.git/')
    }
    return filecmp.cmpfiles(a, b, wanted, shallow=False)
107
108
def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page, recording its body and final URL.

    Sets channel.channel_html and channel.forwarded_url.  Fails (via v) if
    the response status is not 200 or if the request was not redirected
    away from channel.url.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # even when a verification step below raises.
    with urllib.request.urlopen(channel.url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        v.result(request.status == 200)
    v.check('Got forwarded', channel.url != channel.forwarded_url)
116
117
def parse_channel(v: Verification, channel: Channel) -> None:
    """Pull the release name, git commit and file table out of the page.

    Reads channel.channel_html and sets channel.release_name,
    channel.git_commit and channel.table.  Fails (via v) when the <title>
    and <h1> name different releases, or when the text just before the
    commit's <tt> element is not 'Git commit '.
    """
    v.status('Parsing channel description as XML')
    dom = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    def first_text(tag: str) -> str:
        # Text of the first child of the first <tag> element in the page.
        return dom.getElementsByTagName(tag)[0].firstChild.nodeValue

    v.status('Extracting release name:')
    # The release name is taken to be the third whitespace-separated word
    # of both the <title> and the <h1>; the two must agree.
    title_name = first_text('title').split()[2]
    h1_name = first_text('h1').split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    commit_node = dom.getElementsByTagName('tt')[0]
    channel.git_commit = commit_node.firstChild.nodeValue
    v.status(channel.git_commit)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # The first <tr> is skipped (presumably the header row).  Each remaining
    # row is read as: <link to file>, size, <wrapped digest>.
    for row in dom.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()
149
150
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 digest of the byte string s."""
    hasher = hashlib.sha256()
    hasher.update(s)
    hex_digest = hasher.hexdigest()
    return Digest16(hex_digest)
153
154
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded SHA-256 digest of the file's contents.

    The file is read in binary mode in 4096-byte chunks, so large files are
    not loaded into memory all at once.
    """
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
162
163
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to hex with `nix to-base16`.

    Fails (via v) if the nix command exits non-zero.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
170
171
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to base32 with `nix to-base32`.

    Fails (via v) if the nix command exits non-zero.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, capture_output=True)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
178
179
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url with nix-prefetch-url and verify it against digest.

    Both the digest reported by nix-prefetch-url and an independent SHA-256
    of the downloaded file must equal digest.

    Returns:
        The local path printed by `nix-prefetch-url --print-path`.

    Raises:
        VerificationError: (via v) if the fetch or either digest check fails.
        ValueError: if nix-prefetch-url's output is not exactly two
            newline-terminated lines (digest, then path).
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], capture_output=True)
    v.result(process.returncode == 0)
    # Validate the output shape explicitly rather than with `assert`, which
    # is stripped when Python runs with -O.
    fields = process.stdout.decode().split('\n')
    if len(fields) != 3 or fields[2] != '':
        raise ValueError(
            'Unexpected nix-prefetch-url output: %r' % process.stdout)
    prefetch_digest, path, _ = fields
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path
196
197
def fetch_resources(v: Verification, channel: Channel) -> None:
    """Download and verify the channel's git-revision and nixexprs files.

    Each resource's URL is resolved against channel.forwarded_url, fetched,
    and checked against the digest in channel.table.  The local path is
    stored in the entry's `file`.  Finally the contents of git-revision
    (first 999 characters) must equal channel.git_commit.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        url = urllib.parse.urljoin(channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(v, url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file deterministically instead of leaking the handle.
    with open(channel.table['git-revision'].file) as revision_file:
        revision = revision_file.read(999)
    v.result(revision == channel.git_commit)
207
208
def git_fetch(v: Verification, channel: Channel) -> None:
    """Make sure channel.git_commit is in a local bare cache repository.

    The cache lives under ~/.cache/nix-pin-channel/git/, keyed by the SHA-256
    of the channel URL, and its path is stored in channel.git_cachedir.  If
    the commit is missing, channel.git_ref is fetched from channel.git_repo.
    The commit must then be an ancestor of that ref.  Every failing step
    aborts via v.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    # TODO: Consider using pyxdg to find this path.
    url_key = digest_string(channel.url.encode())
    channel.git_cachedir = os.path.expanduser(
        '~/.cache/nix-pin-channel/git/%s' % url_key)
    repo = channel.git_cachedir

    def git(*args: str) -> int:
        # Run a git subcommand inside the cache repo; return its exit code.
        return subprocess.run(['git', '-C', repo] + list(args)).returncode

    if not os.path.exists(repo):
        v.status("Initializing git repo")
        init = subprocess.run(['git', 'init', '--bare', repo])
        v.result(init.returncode == 0)

    v.status('Checking if we already have this rev:')
    # `git cat-file -e` exits 0 if the object exists, 1 if it does not;
    # anything else is treated as a failure.
    have_rev = git('cat-file', '-e', channel.git_commit)
    if have_rev == 0:
        v.status('yes')
    elif have_rev == 1:
        v.status('no')
    v.result(have_rev in (0, 1))

    if have_rev == 1:
        v.status('Fetching ref "%s"' % channel.git_ref)
        # We don't use --force here because we want to abort and freak out if forced
        # updates are happening.
        refspec = '%s:%s' % (channel.git_ref, channel.git_ref)
        v.result(git('fetch', channel.git_repo, refspec) == 0)
        v.status('Verifying that fetch retrieved this rev')
        v.result(git('cat-file', '-e', channel.git_commit) == 0)

    v.status('Verifying rev is an ancestor of ref')
    is_ancestor = git('merge-base', '--is-ancestor',
                      channel.git_commit, channel.git_ref)
    v.result(is_ancestor == 0)
261
262
def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Check that the channel tarball matches the git tree of its commit.

    The nixexprs tarball and `git archive` of channel.git_commit are
    extracted into temporary directories and compared file by file.  This
    fails (via v) when no files match, any files differ, or the set of files
    present on only one side is not exactly the expected list.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:
        tarball = channel.table['nixexprs.tar.xz'].file
        v.status('Extracting tarball %s' % tarball)
        shutil.unpack_archive(tarball, channel_contents)
        v.ok()

        v.status('Checking out corresponding git revision')
        # Stream `git archive` straight into `tar x`.
        archiver = subprocess.Popen(
            ['git', '-C', channel.git_cachedir, 'archive', channel.git_commit],
            stdout=subprocess.PIPE)
        extractor = subprocess.Popen(
            ['tar', 'x', '-C', git_contents, '-f', '-'],
            stdin=archiver.stdout)
        # Drop our copy of the pipe so only tar holds the read end.
        archiver.stdout.close()
        extractor.wait()
        archiver.wait()
        v.result(archiver.returncode == 0 and extractor.returncode == 0)

        v.status('Comparing channel tarball with git checkout')
        # The tarball's contents sit under a directory named after the release.
        tarball_root = os.path.join(channel_contents, channel.release_name)
        match, mismatch, errors = compare(tarball_root, git_contents)
        v.ok()
        v.check('%d files match' % len(match), len(match) > 0)
        v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

        # Files allowed to be incomparable (e.g. present on only one side).
        expected_errors = [
            '.git-revision',
            '.version-suffix',
            'nixpkgs',
            'programs.sqlite',
            'svn-revision']
        benign_errors = [ee for ee in expected_errors if ee in errors]
        unexpected_errors = [e for e in errors if e not in benign_errors]
        v.check(
            '%d unexpected incomparable files' % len(unexpected_errors),
            len(unexpected_errors) == 0)
        v.check(
            '(%d of %d expected incomparable files)' %
            (len(benign_errors), len(expected_errors)),
            len(benign_errors) == len(expected_errors))
        v.status('Removing temporary directories')
    v.ok()
313
314
def main() -> None:
    """Run every verification step against the nixos-20.03 channel.

    When all steps pass, the resulting Channel is printed.
    """
    v = Verification()
    channel = Channel(url='https://channels.nixos.org/nixos-20.03',
                      git_repo='https://github.com/NixOS/nixpkgs.git',
                      git_ref='nixos-20.03')
    # Order matters: each step relies on attributes set by earlier ones.
    steps = (fetch,
             parse_channel,
             fetch_resources,
             git_fetch,
             check_channel_contents)
    for step in steps:
        step(v, channel)
    print(channel)
326
327
# Run only when executed as a script, so that importing this module (for
# tests or reuse) does not start network fetches and git operations.
if __name__ == '__main__':
    main()