]> git.scottworley.com Git - pinch/blob - pinch.py
f974d1c0adc3982dbc08e469ed0eed4c4f1b682b
[pinch] / pinch.py
1 import filecmp
2 import functools
3 import hashlib
4 import operator
5 import os
6 import os.path
7 import shutil
8 import tempfile
9 import urllib.parse
10 import urllib.request
11 import xml.dom.minidom
12
13 from typing import (
14 Any,
15 Dict,
16 Iterable,
17 List,
18 Sequence,
19 Tuple,
20 )
21
22
class VerificationError(Exception):
    """Raised by Verification.result when a verification step fails."""
25
26
class Verification:
    """Print a checklist of steps, each closed by a right-aligned OK/FAIL.

    Announce a step with ``status``, then close it with ``result`` (or ``ok``
    for an unconditional pass).  ``check`` does both in one call.  A failed
    step raises VerificationError after printing FAIL.
    """

    # Outcome -> (verdict text padded to a common width, ANSI SGR color code).
    _VERDICTS = {True: ('OK ', 92), False: ('FAIL', 91)}

    def __init__(self) -> None:
        # Columns already printed on the current line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print step description *s* plus a trailing space, no newline."""
        print(s, end=' ', flush=True)
        # len() counts code points, not display width ("Unicode??").
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Return *s* wrapped in the ANSI escape for color *c*, then a reset."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print the verdict for *r* right-aligned and end the line.

        Raises:
            VerificationError: if *r* is false.
        """
        message, color = self._VERDICTS[r]
        width = shutil.get_terminal_size().columns
        used = self.line_length + len(message)
        print(' ' * ((width - used) % width) + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately close it with outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Close the current step as passed."""
        self.result(True)
56
57
def compare(a: str,
            b: str) -> Tuple[Sequence[str],
                             Sequence[str],
                             Sequence[str]]:
    """Compare the file trees rooted at *a* and *b* by content.

    Every file (and every symlink to a directory) found under either root is
    compared with ``filecmp.cmpfiles(..., shallow=False)``.  Top-level git
    metadata -- a ``.git`` directory or a ``.git`` file -- is ignored.

    Returns:
        ``(match, mismatch, errors)`` as produced by ``filecmp.cmpfiles``, each
        a sorted list of paths relative to the roots.  A path present in only
        one tree cannot be compared and lands in ``errors``.

    Raises:
        OSError: if either tree cannot be walked.
    """

    def throw(error: OSError) -> None:
        # os.walk ignores errors unless given onerror; surface them instead.
        raise error

    def join(x: str, y: str) -> str:
        # relpath() yields '.' for the root itself; avoid producing './name'.
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            # os.walk lists symlinks to directories in `dirs` without
            # descending into them; record each such link as an entry.
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        # '.git' itself may be a plain file (worktrees, submodules).
        return (f for f in files
                if f != '.git' and not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    # Sort so the returned lists have a deterministic order; set iteration
    # order is arbitrary.
    return filecmp.cmpfiles(a, b, sorted(files), shallow=False)
87
88
def fetch(v: Verification, channel_url: str) -> Dict[str, Any]:
    """Download the channel page at *channel_url*.

    Returns a new info dict with keys:
      'url'           -- *channel_url* as given
      'channel_html'  -- the raw response body (bytes)
      'forwarded_url' -- the final URL reported by the response (after any
                         redirects)

    Reports two steps on *v*: the fetch itself (HTTP status 200 expected) and
    that the final URL differs from *channel_url*.

    Raises:
        VerificationError: (via v.result / v.check) if either step fails.
        urllib.error.URLError: if the request itself fails.
    """
    info: Dict[str, Any] = {'url': channel_url}
    v.status('Fetching channel')
    with urllib.request.urlopen(channel_url, timeout=10) as response:
        info['channel_html'] = response.read()
        info['forwarded_url'] = response.geturl()
        fetched = response.status == 200
    v.result(fetched)
    v.check('Got forwarded', info['url'] != info['forwarded_url'])
    return info
99
100
def parse(v: Verification, info: Dict[str, Any]) -> None:
    """Parse info['channel_html'] and record what the page describes.

    Adds to *info*:
      'git_commit' -- text of the first <tt> element on the page
      'table'      -- {name: {'url': ..., 'digest': ..., 'size': int}} built
                      from every <tr> except the first (presumably a header)

    Each stage is announced on *v*.  The label check fails (via v.result)
    unless the node just before that <tt> has the text 'Git commit '.
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(info['channel_html'])
    v.ok()

    v.status('Extracting git commit')
    commit_element = document.getElementsByTagName('tt')[0]
    info['git_commit'] = commit_element.firstChild.nodeValue
    v.ok()
    v.status('Verifying git commit label')
    label = commit_element.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, Any] = {}
    info['table'] = table
    rows = document.getElementsByTagName('tr')[1:]
    for row in rows:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = cells[2].firstChild.firstChild.nodeValue
        table[name] = {'url': url, 'digest': digest, 'size': size}
    v.ok()
122
123
def fetch_resources(v: Verification, info: Dict[str, Any]) -> None:
    """Download and verify the 'git-revision' and 'nixexprs.tar.xz' resources.

    For each resource, its entry in info['table'] supplies a URL (resolved
    against info['forwarded_url']), an advertised size and a SHA-256 hex
    digest.  Resources advertised smaller than 4096 bytes are kept in memory
    under fields['content']; larger ones are streamed to a temporary file
    whose path is stored in fields['file'].  That file is NOT deleted here;
    the caller owns it.

    Finally checks that the git-revision content equals info['git_commit'].

    Raises:
        VerificationError: (via *v*) on a non-200 response, digest mismatch
            or commit mismatch.
        urllib.error.URLError: if a download fails.
    """
    small_limit = 4096
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = info['table'][resource]
        in_memory = fields['size'] < small_limit
        v.status('Fetching resource "%s"' % resource)
        url = urllib.parse.urljoin(info['forwarded_url'], fields['url'])
        with urllib.request.urlopen(url, timeout=10) as response:
            if in_memory:
                fields['content'] = response.read()
            else:
                # Use the resource's own name as suffix so that
                # shutil.unpack_archive can infer the archive format.
                with tempfile.NamedTemporaryFile(
                        suffix='.' + resource, delete=False) as tmp_file:
                    # Record the path first so the caller can clean it up
                    # even if the copy fails part-way.
                    fields['file'] = tmp_file.name
                    shutil.copyfileobj(response, tmp_file)
            fetched = response.status == 200
        v.result(fetched)

        v.status('Verifying digest for "%s"' % resource)
        if in_memory:
            actual_hash = hashlib.sha256(fields['content']).hexdigest()
        else:
            hasher = hashlib.sha256()
            with open(fields['file'], 'rb') as f:
                for block in iter(functools.partial(f.read, small_limit), b''):
                    hasher.update(block)
            actual_hash = hasher.hexdigest()
        v.result(actual_hash == fields['digest'])
    v.check('Verifying git commit on main page matches git commit in table',
            info['table']['git-revision']['content'].decode() == info['git_commit'])
151
152
def extract_channel(v: Verification, info: Dict[str, Any]) -> None:
    """Unpack the downloaded nixexprs.tar.xz into a scratch directory.

    Reads the archive path from info['table']['nixexprs.tar.xz']['file'].
    The extracted tree is currently discarded: the scratch directory is
    removed before this function returns.
    """
    with tempfile.TemporaryDirectory() as d:
        v.status('Extracting nixexprs.tar.xz')
        shutil.unpack_archive(info['table']['nixexprs.tar.xz']['file'], d)
        v.ok()
        v.status('Removing temporary directory')
    # TemporaryDirectory deletes `d` when the `with` block exits, so report
    # success only after that cleanup has actually run.
    v.ok()
160
161
def main() -> None:
    """Fetch, parse, download, verify and unpack the nixos-20.03 channel.

    Prints each step's progress, then the collected info dict.  Any temporary
    download recorded under info['table'][...]['file'] is removed before
    returning, whether or not verification succeeded.
    """
    v = Verification()
    info = fetch(v, 'https://channels.nixos.org/nixos-20.03')
    try:
        parse(v, info)
        fetch_resources(v, info)
        extract_channel(v, info)
        print(info)
    finally:
        # fetch_resources writes large downloads with
        # NamedTemporaryFile(delete=False), so they must be removed here.
        for fields in info.get('table', {}).values():
            if 'file' in fields:
                try:
                    os.remove(fields['file'])
                except FileNotFoundError:
                    pass  # already gone; nothing to clean up


if __name__ == '__main__':
    main()