"""Pin nix channels to specific git revisions, verifying as we go.

Reads a configparser-style channels file, fetches each channel's release
page and tarball, cross-checks the advertised git revision against a local
bare git cache, and writes the pinned metadata back to the channels file.
"""

import argparse
import configparser
import filecmp
import functools
import hashlib
import operator
import os
import os.path
import shutil
import subprocess
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom
from typing import (
    Dict,
    Iterable,
    List,
    NewType,
    Tuple,
)

# sha256 digests in the two encodings nix uses: hex (base16) and nix's base32.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    """One row of the channel release page's file table."""
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


class Channel(types.SimpleNamespace):
    """Mutable state for one channel being pinned.

    Attributes are filled in progressively as the pinning pipeline runs;
    initial attributes come straight from the config file section.
    """
    channel_html: bytes
    channel_url: str
    forwarded_url: str
    git_cachedir: str
    git_ref: str
    git_repo: str
    git_revision: str
    old_git_revision: str
    release_name: str
    table: Dict[str, ChannelTableEntry]


class VerificationError(Exception):
    """Raised by Verification.result when a check fails."""
    pass


class Verification:
    """Prints progress messages and right-aligned OK/FAIL verdicts."""

    def __init__(self) -> None:
        # Running width of the current output line, used to pad the verdict.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print a progress message (no newline) and track its width."""
        print(s, end=' ', flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in the ANSI SGR escape for color code c."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Print OK (green) or FAIL (red) right-aligned; raise on failure."""
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        cols = shutil.get_terminal_size().columns
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color))
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print a status message followed immediately by its verdict."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Shorthand for an unconditionally-passing result."""
        self.result(True)


def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Recursively compare directories a and b.

    Returns filecmp.cmpfiles' (match, mismatch, errors) lists over the
    union of both trees' files (symlinked directories included as files),
    excluding anything under .git/.
    """

    def throw(error: OSError) -> None:
        # os.walk swallows errors by default; we want them fatal.
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        """All files under d, relative to d; symlinked dirs count as files."""
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_,
        (set(exclude_dot_git(recursive_files(x))) for x in [a, b]))
    return filecmp.cmpfiles(a, b, files, shallow=False)


def fetch(v: Verification, channel: Channel) -> None:
    """Download the channel page, recording its HTML and final URL."""
    v.status('Fetching channel')
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel.channel_html = request.read()
        channel.forwarded_url = request.geturl()
        v.result(request.status == 200)
    # Channel URLs are expected to redirect to a concrete release.
    v.check('Got forwarded', channel.channel_url != channel.forwarded_url)


def parse_channel(v: Verification, channel: Channel) -> None:
    """Extract release name, git revision, and the file table from the HTML."""
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel.channel_html)
    v.ok()

    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    # The <title> and <h1> must agree on the release name.
    v.result(title_name == h1_name)
    channel.release_name = title_name

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    channel.git_revision = git_commit_node.firstChild.nodeValue
    v.status(channel.git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    channel.table = {}
    # Skip the header row; each remaining row is (link, size, sha256).
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        channel.table[name] = ChannelTableEntry(
            url=url, digest=digest, size=size)
    v.ok()


def digest_string(s: bytes) -> Digest16:
    """sha256 of a byte string, as lowercase hex."""
    return Digest16(hashlib.sha256(s).hexdigest())


def digest_file(filename: str) -> Digest16:
    """sha256 of a file's contents, as lowercase hex (streamed in 4K blocks)."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # pylint: disable=cell-var-from-loop
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return Digest16(hasher.hexdigest())


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to hex via `nix to-base16`."""
    v.status('Converting digest to base16')
    process = subprocess.run(
        ['nix', 'to-base16', '--type', 'sha256', digest32],
        capture_output=True)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to base32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    process = subprocess.run(
        ['nix', 'to-base32', '--type', 'sha256', digest16],
        capture_output=True)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())


def fetch_with_nix_prefetch_url(
        v: Verification, url: str, digest: Digest16) -> str:
    """Fetch url into the nix store, verifying digest twice; return the path."""
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        capture_output=True)
    v.result(process.returncode == 0)
    # --print-path output: base32 digest, store path, trailing newline.
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    # Belt-and-braces: hash the file ourselves as well.
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path


def fetch_resources(v: Verification, channel: Channel) -> None:
    """Fetch the channel's git-revision file and nixexprs tarball."""
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = channel.table[resource]
        # Table hrefs are relative to the (post-redirect) channel page.
        fields.absolute_url = urllib.parse.urljoin(
            channel.forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(channel.table['git-revision'].file) as f:
        v.result(f.read(999) == channel.git_revision)


def git_fetch(v: Verification, channel: Channel) -> None:
    """Ensure channel.git_revision is present in our local bare git cache.

    Sets channel.git_cachedir, fetches channel.git_ref from channel.git_repo
    if needed, and verifies ancestry relations.
    """
    # It would be nice if we could share the nix git cache, but as of the
    # time of writing it is transitioning from gitv2 (deprecated) to gitv3
    # (not ready yet), and trying to straddle them both is too far into nix
    # implementation details for my comfort.  So we re-implement here half
    # of nix.fetchGit.  :(

    # TODO: Consider using pyxdg to find this path.
    channel.git_cachedir = os.path.expanduser(
        '~/.cache/nix-pin-channel/git/%s' %
        digest_string(channel.git_repo.encode()))
    if not os.path.exists(channel.git_cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', channel.git_cachedir])
        v.result(process.returncode == 0)

    have_rev = False
    if hasattr(channel, 'git_revision'):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', channel.git_cachedir, 'cat-file', '-e',
             channel.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        # cat-file -e: 0 = object exists, 1 = missing; anything else is an
        # actual git failure.
        v.result(process.returncode == 0 or process.returncode == 1)
        have_rev = process.returncode == 0

    if not have_rev:
        v.status(
            'Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
        # We don't use --force here because we want to abort and freak out if
        # forced updates are happening.
        process = subprocess.run(
            ['git', '-C', channel.git_cachedir, 'fetch', channel.git_repo,
             '%s:%s' % (channel.git_ref, channel.git_ref)])
        v.result(process.returncode == 0)
        if hasattr(channel, 'git_revision'):
            v.status('Verifying that fetch retrieved this rev')
            process = subprocess.run(
                ['git', '-C', channel.git_cachedir, 'cat-file', '-e',
                 channel.git_revision])
            v.result(process.returncode == 0)

    if not hasattr(channel, 'git_revision'):
        # No pinned revision yet: take whatever the fetched ref points at.
        with open(os.path.join(channel.git_cachedir, 'refs', 'heads',
                               channel.git_ref)) as f:
            channel.git_revision = f.read(999).strip()

    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(
        ['git', '-C', channel.git_cachedir, 'merge-base', '--is-ancestor',
         channel.git_revision, channel.git_ref])
    v.result(process.returncode == 0)

    if hasattr(channel, 'old_git_revision'):
        # Guard against the channel moving backwards / being rewritten.
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            channel.old_git_revision)
        process = subprocess.run(
            ['git', '-C', channel.git_cachedir, 'merge-base', '--is-ancestor',
             channel.old_git_revision, channel.git_revision])
        v.result(process.returncode == 0)


def compare_tarball_and_git(
        v: Verification,
        channel: Channel,
        channel_contents: str,
        git_contents: str) -> None:
    """Compare the unpacked channel tarball against the git checkout."""
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, channel.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Files the channel build adds/removes relative to the git tree, so
    # they legitimately fail to compare.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' % len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))


def extract_tarball(v: Verification, channel: Channel, dest: str) -> None:
    """Unpack the fetched nixexprs tarball into dest."""
    v.status('Extracting tarball %s' % channel.table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(channel.table['nixexprs.tar.xz'].file, dest)
    v.ok()


def git_checkout(v: Verification, channel: Channel, dest: str) -> None:
    """Materialize channel.git_revision into dest via git archive | tar."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(
        ['git', '-C', channel.git_cachedir, 'archive', channel.git_revision],
        stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Close our copy of the pipe so tar sees EOF when git exits.
    git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)


def check_channel_metadata(
        v: Verification,
        channel: Channel,
        channel_contents: str) -> None:
    """Verify the tarball's embedded revision and version-suffix metadata."""
    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(channel_contents, channel.release_name,
                           '.git-revision')) as f:
        v.result(f.read(999) == channel.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        channel.release_name)
    with open(os.path.join(channel_contents, channel.release_name,
                           '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(channel.release_name.endswith(version_suffix))


def check_channel_contents(v: Verification, channel: Channel) -> None:
    """Unpack tarball and git checkout into temp dirs and cross-check them."""
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, channel, channel_contents)
        check_channel_metadata(v, channel, channel_contents)

        git_checkout(v, channel, git_contents)

        compare_tarball_and_git(v, channel, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()


def pin_channel(v: Verification, channel: Channel) -> None:
    """Run the full fetch-parse-verify pipeline for one channel."""
    fetch(v, channel)
    parse_channel(v, channel)
    fetch_resources(v, channel)
    git_fetch(v, channel)
    check_channel_contents(v, channel)


def make_channel(conf: configparser.SectionProxy) -> Channel:
    """Build a Channel from a config section.

    A pre-existing git_revision is moved aside to old_git_revision so the
    pipeline can detect non-fast-forward updates.
    """
    channel = Channel(**dict(conf.items()))
    if hasattr(channel, 'git_revision'):
        channel.old_git_revision = channel.git_revision
        del channel.git_revision
    return channel


def pin(args: argparse.Namespace) -> None:
    """Pin every channel in args.channels_file and write the result back."""
    v = Verification()
    config = configparser.ConfigParser()
    with open(args.channels_file) as f:
        config.read_file(f, args.channels_file)
    for section in config.sections():
        channel = make_channel(config[section])
        if 'channel_url' in config[section]:
            # Full channel: pin tarball metadata as well as the revision.
            pin_channel(v, channel)
            config[section]['name'] = channel.release_name
            config[section]['tarball_url'] = channel.table[
                'nixexprs.tar.xz'].absolute_url
            config[section]['tarball_sha256'] = channel.table[
                'nixexprs.tar.xz'].digest
        else:
            # Plain git repo: just resolve and pin the revision.
            git_fetch(v, channel)
        config[section]['git_revision'] = channel.git_revision

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)


def main() -> None:
    """Command-line entry point."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.set_defaults(func=pin)
    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()