git.scottworley.com Git - pinch/blob - pinch.py
Use SimpleNamespace for typed records
[pinch] / pinch.py
1 import filecmp
2 import functools
3 import hashlib
4 import operator
5 import os
6 import os.path
7 import shutil
8 import tempfile
9 import types
10 import urllib.parse
11 import urllib.request
12 import xml.dom.minidom
13
14 from typing import (
15 Dict,
16 Iterable,
17 List,
18 Sequence,
19 Tuple,
20 )
21
22
class InfoTableEntry(types.SimpleNamespace):
    """Record for one file listed in the channel page's table.

    Instances are created with ``url``, ``digest`` and ``size``; ``content``
    or ``file`` is attached later, once the resource has been downloaded.
    """
    # Raw bytes of the resource, set when it is read straight into memory.
    content: bytes
    # Hex digest from the table; compared against a SHA-256 of the download.
    digest: str
    # Path of the temporary file holding the resource when it was spooled
    # to disk instead of being kept in memory.
    file: str
    # Size parsed from the table as an integer.
    size: int
    # Value of the link's href attribute (joined to the channel URL on use).
    url: str
29
30
class Info(types.SimpleNamespace):
    """Everything learned about a channel while fetching and verifying it.

    Attributes are filled in incrementally as the verification proceeds,
    which is why this is a namespace rather than a constructor-initialised
    record.
    """
    # Raw body of the channel page.
    channel_html: bytes
    # URL the channel request finally resolved to.
    forwarded_url: str
    # Commit identifier extracted from the channel page.  Named to match the
    # attribute the parsing and verification code actually sets and reads.
    git_commit: str
    # Table rows of the channel page, keyed by file name.
    table: Dict[str, 'InfoTableEntry']
    # Channel URL as originally requested.
    url: str
37
38
class VerificationError(Exception):
    """Raised when a verification step reports failure."""
41
42
class Verification:
    """Console reporter for a sequence of verification steps.

    Each step is announced with ``status`` and concluded with ``result``,
    which prints a coloured verdict aligned to the right edge of the
    terminal and aborts the run on failure.
    """

    def __init__(self) -> None:
        # Characters already printed on the current output line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space, staying on the current line."""
        print(s, end=' ', flush=True)
        # len() counts code points, not display columns, so wide or
        # combining characters may throw the alignment off.
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Return *s* wrapped in the escape sequence for colour code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print the verdict for the current step and end the line.

        Raises:
            VerificationError: if *r* is false.
        """
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        cols = shutil.get_terminal_size().columns
        # Fill the rest of the line so the verdict ends at the last column;
        # the modulo handles text that has already wrapped.
        pad = (cols - self.line_length - len(message)) % cols
        print(' ' * pad + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately report *r* as its outcome."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report the current step as successful."""
        self.result(True)
72
73
def compare(a: str,
            b: str) -> Tuple[Sequence[str],
                             Sequence[str],
                             Sequence[str]]:
    """Compare the contents of directory trees *a* and *b* file by file.

    Every regular file, plus every symlink that ``os.walk`` reports among
    the directories, found in either tree is compared by content.  Paths
    beginning with ``.git/`` are ignored.

    Returns:
        The ``(match, mismatch, errors)`` triple produced by
        ``filecmp.cmpfiles``.  Paths are relative to the tree roots and each
        list is sorted, so the result is the same from run to run.

    Raises:
        OSError: if a directory cannot be listed while walking either tree.
    """

    def throw(error: OSError) -> None:
        # os.walk ignores listing errors unless onerror re-raises them.
        raise error

    def join(x: str, y: str) -> str:
        # Avoid a leading './' on entries directly under the root.
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            for f in files:
                yield join(rel, f)
            # Symlinks to directories are listed under dirs but are not
            # descended into, so record the link itself as an entry.
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    yield join(rel, dir_or_link)

    names = {f
             for root in (a, b)
             for f in recursive_files(root)
             if not f.startswith('.git/')}
    # Sort so the order of the returned lists does not depend on set
    # iteration order (which varies with string hash randomisation).
    return filecmp.cmpfiles(a, b, sorted(names), shallow=False)
103
104
def fetch(v: Verification, channel_url: str) -> Info:
    """Download the channel page at *channel_url*.

    Args:
        v: Reporter used to announce each step and its outcome.
        channel_url: URL of the channel to fetch.

    Returns:
        An ``Info`` with ``url``, ``channel_html`` and ``forwarded_url`` set.

    Raises:
        VerificationError: if the response status is not 200, or if the
            request ended up at the same URL it started from.
        urllib.error.URLError: if the request itself fails.
    """
    info = Info()
    info.url = channel_url
    v.status('Fetching channel')
    # Request the URL the caller asked for, and close the response as soon
    # as everything needed has been read from it.
    with urllib.request.urlopen(channel_url, timeout=10) as request:
        info.channel_html = request.read()
        info.forwarded_url = request.geturl()
        status = request.status
    v.result(status == 200)
    v.check('Got forwarded', info.url != info.forwarded_url)
    return info
116
117
def parse(v: Verification, info: Info) -> None:
    """Extract the git commit and the file table from the channel page.

    Reads ``info.channel_html`` and stores the results on *info* as
    ``git_commit`` and ``table`` (file name -> ``InfoTableEntry``).

    Raises:
        VerificationError: if the text preceding the commit is not the
            expected label.
    """
    v.status('Parsing channel description as XML')
    document = xml.dom.minidom.parseString(info.channel_html)
    v.ok()

    v.status('Extracting git commit')
    # The commit is taken from the first <tt> element in the document.
    commit_element = document.getElementsByTagName('tt')[0]
    info.git_commit = commit_element.firstChild.nodeValue
    v.ok()
    v.status('Verifying git commit label')
    label = commit_element.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    info.table = {}
    # The first <tr> is skipped -- presumably the header row.
    data_rows = document.getElementsByTagName('tr')[1:]
    for row in data_rows:
        cells = row.childNodes
        # First cell wraps a link: its text is the file name, its href
        # the location.
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        href = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = cells[2].firstChild.firstChild.nodeValue
        info.table[name] = InfoTableEntry(url=href, digest=digest, size=size)
    v.ok()
139
140
def _sha256_of_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*."""
    hasher = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read in fixed-size blocks so large files are never held in memory.
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return hasher.hexdigest()


def fetch_resources(v: Verification, info: Info) -> None:
    """Download the channel's resources and verify their digests.

    For each of ``git-revision`` and ``nixexprs.tar.xz`` the table entry in
    ``info.table`` gains either ``content`` (entries whose listed size is
    under 4096) or ``file`` (path of a temporary file holding the download).
    The temporary file is created with ``delete=False`` and is not removed
    here.

    Raises:
        VerificationError: if a response status is not 200, a SHA-256 digest
            does not match the table, or the downloaded git revision differs
            from the commit found on the channel page.
        KeyError: if a resource is missing from ``info.table``.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = info.table[resource]
        # Small resources are kept in memory; larger ones go to disk.
        in_memory = fields.size < 4096
        v.status('Fetching resource "%s"' % resource)
        url = urllib.parse.urljoin(info.forwarded_url, fields.url)
        # The context manager closes the HTTP response even if reading or
        # writing fails part-way through.
        with urllib.request.urlopen(url, timeout=10) as request:
            if in_memory:
                fields.content = request.read()
            else:
                with tempfile.NamedTemporaryFile(
                        suffix='.nixexprs.tar.xz', delete=False) as tmp_file:
                    shutil.copyfileobj(request, tmp_file)
                fields.file = tmp_file.name
            status = request.status
        v.result(status == 200)
        v.status('Verifying digest for "%s"' % resource)
        if in_memory:
            actual_hash = hashlib.sha256(fields.content).hexdigest()
        else:
            actual_hash = _sha256_of_file(fields.file)
        v.result(actual_hash == fields.digest)
    v.check('Verifying git commit on main page matches git commit in table',
            info.table['git-revision'].content.decode() == info.git_commit)
168
169
def extract_channel(v: Verification, info: Info) -> None:
    """Unpack the downloaded ``nixexprs.tar.xz`` into a scratch directory.

    The archive path is read from ``info.table['nixexprs.tar.xz'].file``.
    The scratch directory is discarded again before returning, so for now
    this only demonstrates that the archive can be unpacked.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        v.status('Extracting nixexprs.tar.xz')
        archive_path = info.table['nixexprs.tar.xz'].file
        shutil.unpack_archive(archive_path, scratch_dir)
        v.ok()
        # Announced here; the removal itself happens when the block exits.
        v.status('Removing temporary directory')
    v.ok()
177
178
def main() -> None:
    """Fetch and verify the nixos-20.03 channel, then print what was found."""
    verification = Verification()
    channel = fetch(verification, 'https://channels.nixos.org/nixos-20.03')
    # Each stage takes the reporter and the accumulating record, in order.
    for stage in (parse, fetch_resources, extract_channel):
        stage(verification, channel)
    print(channel)
186
187
188 main()