+import filecmp
+import functools
+import hashlib
+import operator
+import os
+import os.path
+import shutil
+import tempfile
+import urllib.parse
+import urllib.request
+import xml.dom.minidom
+
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Sequence,
+    Tuple,
+)
+
+
+class VerificationError(Exception):
+    pass
+
+
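+# Prints progress/status text and right-aligns a colored OK/FAIL marker
+# against the current terminal width; a FAIL raises VerificationError.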
+class Verification:
+
+    def __init__(self) -> None:
+        self.line_length = 0
+
+    def status(self, s: str) -> None:
+        print(s, end=' ', flush=True)
+        self.line_length += 1 + len(s)  # Unicode??
+
+    @staticmethod
+    def _color(s: str, c: int) -> str:
+        return '\033[%2dm%s\033[00m' % (c, s)
+
+    def result(self, r: bool) -> None:
+        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
+        length = len(message)
+        cols = shutil.get_terminal_size().columns
+        pad = (cols - (self.line_length + length)) % cols
+        print(' ' * pad + self._color(message, color))
+        self.line_length = 0
+        if not r:
+            raise VerificationError()
+
+    def check(self, s: str, r: bool) -> None:
+        self.status(s)
+        self.result(r)
+
+    def ok(self) -> None:
+        self.result(True)
+
+
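+# Recursively compares two directory trees over the union of files found in
+# either tree (ignoring anything under .git/), returning filecmp.cmpfiles'
+# usual (match, mismatch, errors) triple.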
+def compare(a: str,
+            b: str) -> Tuple[Sequence[str],
+                             Sequence[str],
+                             Sequence[str]]:
+
+    def throw(error: OSError) -> None:
+        raise error
+
+    def join(x: str, y: str) -> str:
+        return y if x == '.' else os.path.join(x, y)
+
+    def recursive_files(d: str) -> Iterable[str]:
+        all_files: List[str] = []
+        for path, dirs, files in os.walk(d, onerror=throw):
+            rel = os.path.relpath(path, start=d)
+            all_files.extend(join(rel, f) for f in files)
+            for dir_or_link in dirs:
+                if os.path.islink(join(path, dir_or_link)):
+                    all_files.append(join(rel, dir_or_link))
+        return all_files
+
+    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
+        return (f for f in files if not f.startswith('.git/'))
+
+    files = functools.reduce(
+        operator.or_, (set(
+            exclude_dot_git(
+                recursive_files(x))) for x in [a, b]))
+    return filecmp.cmpfiles(a, b, files, shallow=False)
+
+
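+# Downloads the channel page, recording both the URL that was requested and
+# the URL it forwarded to (the release the channel currently points at).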
+def fetch(v: Verification, channel_url: str) -> Dict[str, Any]:
+    info: Dict[str, Any] = {'url': channel_url}
+    v.status('Fetching channel')
+    request = urllib.request.urlopen(channel_url, timeout=10)
+    info['channel_html'] = request.read()
+    info['forwarded_url'] = request.geturl()
+    v.result(request.status == 200)
+    v.check('Got forwarded', info['url'] != info['forwarded_url'])
+    return info
+
+
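+# Scrapes the release page HTML: the git commit lives in the first <tt>
+# element, and each table row lists a file's name/link, size, and digest.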
+def parse(v: Verification, info: Dict[str, Any]) -> None:
+    v.status('Parsing channel description as XML')
+    d = xml.dom.minidom.parseString(info['channel_html'])
+    v.ok()
+
+    v.status('Extracting git commit')
+    git_commit_node = d.getElementsByTagName('tt')[0]
+    info['git_commit'] = git_commit_node.firstChild.nodeValue
+    v.ok()
+    v.status('Verifying git commit label')
+    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')
+
+    v.status('Parsing table')
+    info['table'] = {}
+    for row in d.getElementsByTagName('tr')[1:]:
+        name = row.childNodes[0].firstChild.firstChild.nodeValue
+        url = row.childNodes[0].firstChild.getAttribute('href')
+        size = int(row.childNodes[1].firstChild.nodeValue)
+        digest = row.childNodes[2].firstChild.firstChild.nodeValue
+        info['table'][name] = {'url': url, 'digest': digest, 'size': size}
+    v.ok()
+
+
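+# Fetches each listed resource and checks its digest.  Small resources (like
+# git-revision) are kept in memory; larger ones (the nixexprs tarball) are
+# streamed to a temporary file and hashed in 4096-byte blocks.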
+def fetch_resources(v: Verification, info: Dict[str, Any]) -> None:
+
+    for resource in ['git-revision', 'nixexprs.tar.xz']:
+        fields = info['table'][resource]
+        v.status('Fetching resource "%s"' % resource)
+        url = urllib.parse.urljoin(info['forwarded_url'], fields['url'])
+        request = urllib.request.urlopen(url, timeout=10)
+        if fields['size'] < 4096:
+            fields['content'] = request.read()
+        else:
+            with tempfile.NamedTemporaryFile(
+                    suffix='.nixexprs.tar.xz', delete=False) as tmp_file:
+                shutil.copyfileobj(request, tmp_file)
+                fields['file'] = tmp_file.name
+        v.result(request.status == 200)
+        v.status('Verifying digest for "%s"' % resource)
+        if fields['size'] < 4096:
+            actual_hash = hashlib.sha256(fields['content']).hexdigest()
+        else:
+            hasher = hashlib.sha256()
+            with open(fields['file'], 'rb') as f:
+                # pylint: disable=cell-var-from-loop
+                for block in iter(lambda: f.read(4096), b''):
+                    hasher.update(block)
+            actual_hash = hasher.hexdigest()
+        v.result(actual_hash == fields['digest'])
+    v.check('Verifying git commit on main page matches git commit in table',
+            info['table']['git-revision']['content'].decode() == info['git_commit'])
+
+
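+# Unpacks the nixexprs tarball into a temporary directory to confirm the
+# archive extracts cleanly; the directory is discarded afterwards.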
+def extract_channel(v: Verification, info: Dict[str, Any]) -> None:
+    with tempfile.TemporaryDirectory() as d:
+        v.status('Extracting nixexprs.tar.xz')
+        shutil.unpack_archive(info['table']['nixexprs.tar.xz']['file'], d)
+        v.ok()
+        v.status('Removing temporary directory')
+    v.ok()
+
+
+def main() -> None:
+    v = Verification()
+    info = fetch(v, 'https://channels.nixos.org/nixos-20.03')
+    parse(v, info)
+    fetch_resources(v, info)
+    extract_channel(v, info)
+    print(info)
+
+
+main()