import argparse
import configparser
import filecmp
import functools
import getpass
import hashlib
import operator
import os
import os.path
import shlex
import shutil
import subprocess
import sys
import tarfile
import tempfile
import types
import urllib.parse
import urllib.request
import xml.dom.minidom

from typing import (
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    NewType,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import git_cache

# Use xdg module when it's less painful to have as a dependency


class XDG(NamedTuple):
    XDG_CACHE_HOME: str


# $XDG_CACHE_HOME if set, otherwise ~/.cache.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))


class VerificationError(Exception):
    """Raised by Verification.result when a checked step fails."""
    pass


class Verification:
    """Print progress steps to stderr and abort on the first failed check.

    Each step is announced with `status()` and concluded with `result()`,
    which right-aligns a coloured OK/FAIL marker on the same terminal line
    and raises VerificationError on failure.
    """

    def __init__(self) -> None:
        # Number of characters printed on the current stderr line so far.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print `s` (followed by a space) without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # NOTE: len() counts code points, not terminal columns, so wide or
        # combining characters can misalign the OK/FAIL column.
        self.line_length += 1 + len(s)

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap `s` in the ANSI SGR colour code `c`."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the current line with OK or FAIL.

        Raises:
            VerificationError: if `r` is false.
        """
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        length = len(message)
        cols = shutil.get_terminal_size().columns or 80
        pad = (cols - (self.line_length + length)) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step `s` and immediately report result `r`."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current step."""
        self.result(True)


# Hex (base16) and Nix base32 encodings of a sha256 digest.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    # One row of the channel page's file table. `absolute_url` and `file`
    # are filled in later by fetch_resources().
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


# Pin for an alias: carries no state of its own.
class AliasPin(NamedTuple):
    pass


# Pin for a local symlink: no stored fields; the release name is fixed.
class SymlinkPin(NamedTuple):
    @property
    def release_name(self) -> str:
        return 'link'


# Pin for a plain git search path.
class GitPin(NamedTuple):
    git_revision: str
    release_name: str


# Pin for a Nix channel: git revision plus the verified nixexprs tarball.
class ChannelPin(NamedTuple):
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str


Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]


def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add `filename` with `nix-store --add`; return the printed store path."""
    v.status('Putting tarball in Nix store')
    process = subprocess.run(
        ['nix-store', '--add', filename], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return process.stdout.decode().strip()  # type: ignore  # (for old mypy)


def symlink_archive(v: Verification, path: str) -> str:
    """Store a .tar.gz holding one symlink named `link` pointing at `path`."""
    with tempfile.TemporaryDirectory() as td:
        archive_filename = os.path.join(td, 'link.tar.gz')
        os.symlink(path, os.path.join(td, 'link'))
        with tarfile.open(archive_filename, mode='x:gz') as t:
            t.add(os.path.join(td, 'link'), arcname='link')
        # Must run before the temporary directory is removed.
        return copy_to_nix_store(v, archive_filename)


class AliasSearchPath(NamedTuple):
    """A channel that reuses the expression of channel `alias_of`."""
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        return AliasPin()


class SymlinkSearchPath(NamedTuple):
    """A channel backed by a symlink to a local `path`."""
    path: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        return symlink_archive(v, self.path)


class GitSearchPath(NamedTuple):
    """A channel built directly from `git_ref` of `git_repo`."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch the ref and pin its revision.

        If there is a previous pin, first verify it is an ancestor of the
        new revision.
        """
        _, new_revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, new_revision)
        return GitPin(release_name=git_revision_name(v, self, new_revision),
                      git_revision=new_revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return a store path to a tarball of the pinned revision."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, pin.git_revision)
        return git_get_tarball(v, self, pin)


class ChannelSearchPath(NamedTuple):
    """A Nix channel, cross-checked against its git repository."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Fetch the channel, verify it against git, and pin its tarball.

        Returns `old_pin` unchanged if the channel's git revision has not
        moved.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        channel_html, forwarded_url = fetch_channel(v, self)
        table, new_gitpin = parse_channel(v, channel_html)
        if old_pin is not None and old_pin.git_revision == new_gitpin.git_revision:
            return old_pin
        fetch_resources(v, new_gitpin, forwarded_url, table)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, new_gitpin.git_revision)
        if old_pin is not None:
            verify_git_ancestry(
                v, self, old_pin.git_revision, new_gitpin.git_revision)
        check_channel_contents(v, self, table, new_gitpin)
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=table['nixexprs.tar.xz'].absolute_url,
            tarball_sha256=table['nixexprs.tar.xz'].digest,
            git_revision=new_gitpin.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Fetch the pinned tarball and verify its sha256."""
        assert isinstance(pin, ChannelPin)

        return fetch_with_nix_prefetch_url(
            v, pin.tarball_url, Digest16(pin.tarball_sha256))


SearchPath = Union[AliasSearchPath,
                   SymlinkSearchPath,
                   GitSearchPath,
                   ChannelSearchPath]
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]


def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file-by-file (contents, not stat).

    The union of relative paths under `a` and `b` is compared. Paths under
    `.git/` are excluded, and symlinks to directories count as entries.

    Returns:
        `filecmp.cmpfiles`'s (match, mismatch, errors) lists.
    """

    def throw(error: OSError) -> None:
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            # os.walk does not descend into symlinked directories, so
            # record the links themselves.
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    return filecmp.cmpfiles(a, b, files, shallow=False)


def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel page.

    Verifies HTTP 200 and that the URL was redirected.

    Returns:
        (page body, final URL after redirects).
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel_html = request.read().decode()
        forwarded_url = request.geturl()
        status = request.status  # type: ignore  # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url


def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Parse the channel page (as XML).

    Extracts:
      * the release name: the third word of <title>, which must agree
        with <h1>;
      * the git commit: the first <tt>, which must be labelled
        'Git commit ';
      * the file table: every <tr> after the header row.
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel_html)
    v.ok()

    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    git_revision = git_commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=title_name, git_revision=git_revision)


def digest_string(s: bytes) -> Digest16:
    """Return the hex sha256 of `s`."""
    return Digest16(hashlib.sha256(s).hexdigest())


def digest_file(filename: str) -> Digest16:
    """Return the hex sha256 of a file, read in 4 KiB blocks."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        # pylint: disable=cell-var-from-loop
        for block in iter(lambda: f.read(4096), b''):
            hasher.update(block)
    return Digest16(hasher.hexdigest())


# Lazily set by _nix_command(): whether `nix` mentions
# --experimental-features in its help output.
_experimental_flag_needed: Optional[bool] = None


def _nix_command(v: Verification) -> List[str]:
    """Return the argv prefix for invoking `nix` subcommands.

    Probes `nix --help` once per process. If it mentions
    --experimental-features, the 'nix-command' feature is enabled
    explicitly.
    """
    global _experimental_flag_needed
    if _experimental_flag_needed is None:
        v.status('Checking Nix version')
        process = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        _experimental_flag_needed = b'--experimental-features' in process.stdout
    return ['nix', '--experimental-features',
            'nix-command'] if _experimental_flag_needed else ['nix']


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to hex with `nix to-base16`."""
    v.status('Converting digest to base16')
    process = subprocess.run(_nix_command(v) + [
        'to-base16',
        '--type',
        'sha256',
        digest32],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to base32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    process = subprocess.run(_nix_command(v) + [
        'to-base32',
        '--type',
        'sha256',
        digest16],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())


def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch `url` into the store with nix-prefetch-url and return its path.

    Both the digest reported by nix-prefetch-url and a locally computed
    sha256 of the fetched file must equal `digest`.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Expected output: "<base32 digest>\n<store path>\n".
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore  # (for old mypy)


def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Fetch 'git-revision' and 'nixexprs.tar.xz' from the channel table.

    Fills in each entry's `absolute_url` and `file`, then checks that the
    git-revision file matches the commit shown on the page.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(table['git-revision'].file) as f:
        table_revision = f.read(999)
    v.result(table_revision == pin.git_revision)


def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the path of the cache file for a generated git tarball.

    The file records the tarball's store path and lives under
    $XDG_CACHE_HOME/pinch/git-tarball/.
    """
    return os.path.join(
        xdg.XDG_CACHE_HOME,
        'pinch/git-tarball',
        '%s-%s-%s' %
        (digest_string(
            channel.git_repo.encode()),
         pin.git_revision,
         pin.release_name))


def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Require `old_revision` to be an ancestor of `new_revision`.

    Uses `git merge-base --is-ancestor`, so a pin can only move forward.
    """
    cachedir = git_cache.git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of previous rev %s' % old_revision)
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              old_revision,
                              new_revision])
    v.result(process.returncode == 0)


def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    Requires:
      * at least one matching file;
      * no differing files;
      * exactly the listed channel-only files to be incomparable.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, pin.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Files present on only one side, which are expected to differ.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))


def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the fetched nixexprs.tar.xz into `dest`."""
    v.status('Extracting tarball %s' % table['nixexprs.tar.xz'].file)
    shutil.unpack_archive(table['nixexprs.tar.xz'].file, dest)
    v.ok()


def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Extract `pin.git_revision` into `dest` via `git archive | tar x`."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cache.git_cachedir(channel.git_repo),
                            'archive',
                            pin.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Drop our copy of the pipe so git sees SIGPIPE if tar exits early.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)


def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a store path to an .tar.xz of the pinned git revision.

    A previously generated tarball is reused if the cache file names a path
    that still exists. Otherwise a tarball is built with
    `git archive | xz`, added to the store, and recorded in the cache.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file's descriptor.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cache.git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Drop our copy of the pipe so git sees SIGPIPE if xz exits
            # early (same pattern as git_checkout).
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'w') as f:
            f.write(store_tarball)
        return store_tarball  # type: ignore  # (for old mypy)


def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the extracted tarball's metadata files against `pin`.

    .git-revision must equal the pinned revision, and .version-suffix must
    be a suffix of the release name.
    """
    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(channel_contents,
                           pin.release_name,
                           '.git-revision')) as f:
        tarball_revision = f.read(999)
    v.result(tarball_revision == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    with open(os.path.join(channel_contents,
                           pin.release_name,
                           '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))


def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Extract the channel tarball and the git revision, then compare them.

    Both trees are extracted into temporary directories that are removed
    afterwards.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, table, channel_contents)
        check_channel_metadata(v, pin, channel_contents)

        git_checkout(v, channel, pin, git_contents)

        compare_tarball_and_git(v, pin, channel_contents, git_contents)

        v.status('Removing temporary directories')
    v.ok()


def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name '<repo basename>-<commit time>-<short hash>'."""
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cache.git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              git_revision],
                             stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())


K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split `d` into (items where pred(k, v) holds, all other items)."""
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for k, v in d.items():
        if pred(k, v):
            selected[k] = v
        else:
            remaining[k] = v
    return selected, remaining


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split `d` into (items whose key is in `fields`, the rest)."""
    return partition_dict(lambda k, v: k in fields, d)


def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path and (optional) pin from one config section.

    The section's `type` key selects the classes. Keys naming fields of the
    pin class become the pin; all remaining keys (except `type`) are passed
    to the search path. The pin is None when no pin keys are present,
    unless the pin class has no fields at all.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    SP, P = mapping[conf['type']]
    _, all_fields = filter_dict(dict(conf.items()), set(['type']))
    pin_fields, remaining_fields = filter_dict(all_fields, set(P._fields))
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin_present = pin_fields != {} or P._fields == ()
    pin = P(**pin_fields) if pin_present else None  # type: ignore
    return SP(**remaining_fields), pin


def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but raise if the section is not pinned."""
    sp, pin = read_config_section(conf)
    if pin is None:
        raise Exception(
            'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
            section)
    return sp, pin


def read_config(filename: str) -> configparser.ConfigParser:
    """Parse one channels file."""
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge sections from several channels files.

    Raises if the same section name appears more than once.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception('Duplicate channel "%s"' % section)
            merged_config[section] = config[section]
    return merged_config


def pinCommand(args: argparse.Namespace) -> None:
    """`pin` subcommand: re-pin sections and write the file back in place.

    Sections are restricted to `args.channels` when that list is given.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        if args.channels and section not in args.channels:
            continue

        sp, old_pin = read_config_section(config[section])

        config[section].update(sp.pin(v, old_pin)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)


def updateCommand(args: argparse.Namespace) -> None:
    """`update` subcommand: fetch every pinned channel and install with nix-env.

    With --dry-run, the nix-env command line is printed instead of run.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = {
        section: read_pinned_config_section(section, conf) for section,
        conf in read_config_files(
            args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda k, v: isinstance(v[0], AliasSearchPath), config)

    for section, (sp, pin) in nonalias.items():
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        # '%%s' survives this formatting and is filled with the channel
        # name when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (pin.release_name, tarball))

    for section, (sp, pin) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        # NOTE(review): an alias whose target is itself an alias only
        # resolves if the target was processed earlier in this loop.
        exprs[section] = exprs[sp.alias_of]

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/fetchurl.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)


def main() -> None:
    """Parse the command line and dispatch to the `pin` / `update` handlers."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)
    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pinCommand)
    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('--profile', default=(
        '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser()))
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=updateCommand)
    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()