]> git.scottworley.com Git - pinch/blame - pinch.py
README is markdown
[pinch] / pinch.py
CommitLineData
0e5e611d 1import argparse
f15e458d 2import configparser
2f96f32a
SW
3import filecmp
4import functools
736c25eb 5import getpass
2f96f32a
SW
6import hashlib
7import operator
8import os
9import os.path
9a78329e 10import shlex
2f96f32a 11import shutil
73bec7e8 12import subprocess
9d7844bb 13import sys
0afcdb2a 14import tarfile
2f96f32a 15import tempfile
89e79125 16import types
2f96f32a
SW
17import urllib.parse
18import urllib.request
19import xml.dom.minidom
20
21from typing import (
9d2c406b 22 Callable,
2f96f32a
SW
23 Dict,
24 Iterable,
25 List,
7f4c3ace 26 Mapping,
0a4ff8dd 27 NamedTuple,
73bec7e8 28 NewType,
d7cfdb22 29 Optional,
567a6783 30 Set,
2f96f32a 31 Tuple,
7f4c3ace 32 Type,
567a6783 33 TypeVar,
0a4ff8dd 34 Union,
2f96f32a
SW
35)
36
d06918bc
SW
37import git_cache
38
3603dde2
SW
39# Use xdg module when it's less painful to have as a dependency
40
41
class XDG(NamedTuple):
    """Subset of the XDG Base Directory settings that pinch uses."""

    # Base directory for user-specific cached data.
    XDG_CACHE_HOME: str


def xdg_cache_home(environ: Optional[Mapping[str, str]] = None) -> str:
    """Return the XDG cache directory per the XDG Base Directory spec.

    ``$XDG_CACHE_HOME`` is honoured only when it is set to a non-empty,
    absolute path; otherwise ``~/.cache`` is used.  The spec requires empty
    or relative values to be ignored -- using them verbatim would put the
    cache relative to the current working directory.

    :param environ: mapping to read the variable from; defaults to
        ``os.environ``.
    """
    env = os.environ if environ is None else environ
    configured = env.get('XDG_CACHE_HOME', '')
    if configured and os.path.isabs(configured):
        return configured
    return os.path.expanduser('~/.cache')


xdg = XDG(XDG_CACHE_HOME=xdg_cache_home())
26125a28
SW
50
51
2f96f32a
SW
class VerificationError(Exception):
    """Raised by Verification.result when a verification step fails."""
54
55
class Verification:
    """Console reporter for a sequence of verification steps.

    ``status`` prints fragments to stderr on one line; ``result`` appends a
    right-aligned, colored OK/FAIL marker, ends the line, and raises
    ``VerificationError`` if the step failed.
    """

    def __init__(self) -> None:
        # Characters printed on the current stderr line so far; result()
        # uses it to right-align its marker.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* and a trailing space to stderr without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, which may differ from terminal cell
        # width for wide or combining characters.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Return *s* wrapped in ANSI SGR color code *c* followed by a reset."""
        return '\033[' + format(c, '2d') + 'm' + s + '\033[00m'

    def result(self, r: bool) -> None:
        """Finish the current line with OK (*r* true) or FAIL (*r* false).

        :raises VerificationError: when *r* is false, after printing FAIL.
        """
        message, color = ('OK ', 92) if r else ('FAIL', 91)
        cols = shutil.get_terminal_size().columns or 80
        # Padding that pushes the marker to the right edge; taken modulo the
        # width so an over-long line still gets a sensible amount.
        pad = (cols - (self.line_length + len(message))) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print status *s* and then its result *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Mark the current step as successful."""
        self.result(True)
85
86
faff8642
SW
# Hex-encoded (base 16) digest string, e.g. hashlib's hexdigest() output.
Digest16 = NewType('Digest16', str)
# Digest in Nix's base-32 encoding (as converted by `nix to-base32`).
Digest32 = NewType('Digest32', str)
89
90
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel creates it with ``url``, ``digest`` and ``size``;
    fetch_resources later fills in ``absolute_url`` and ``file``.
    """

    # URL resolved against the (forwarded) channel page URL.
    absolute_url: str
    # Expected hex SHA-256 of the file, as listed in the table.
    digest: Digest16
    # Local path returned by fetch_with_nix_prefetch_url.
    file: str
    # Value of the table's size column.
    size: int
    # URL exactly as written in the table (may be relative).
    url: str
97
98
0a4ff8dd
SW
class AliasPin(NamedTuple):
    """Pin for an alias channel; it records no fields of its own."""
101
102
0afcdb2a
SW
class SymlinkPin(NamedTuple):
    """Pin for a symlink channel; it records no fields."""

    @property
    def release_name(self) -> str:
        """Constant release name for symlink channels."""
        return 'link'
107
108
0a4ff8dd
SW
class GitPin(NamedTuple):
    """Pinned state of a git channel."""

    # Commit the channel is pinned to.
    git_revision: str
    # Name used as the tarball directory prefix and nix-env entry name.
    release_name: str
112
0a4ff8dd
SW
113
class ChannelPin(NamedTuple):
    """Pinned state of a channel: its git commit plus the verified tarball."""

    # Commit the channel release was built from.
    git_revision: str
    # Release name parsed from the channel page.
    release_name: str
    # Absolute URL of the channel's nixexprs.tar.xz.
    tarball_url: str
    # Hex SHA-256 of that tarball.
    tarball_sha256: str
119
120
0afcdb2a
SW
# Any pin type; read_config_section chooses one from a section's "type".
Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]
122
123
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store with ``nix-store --add``.

    Progress is reported on *v*; a non-zero exit status makes ``v.result``
    raise VerificationError.  Returns the store path nix-store printed.
    """
    v.status('Putting tarball in Nix store')
    completed = subprocess.run(
        ['nix-store', '--add', filename], stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    store_path = completed.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
0a4ff8dd
SW
130
131
96063a51
SW
def symlink_archive(v: Verification, path: str) -> str:
    """Build a .tar.gz holding a symlink ``link`` -> *path* and store it.

    Returns the Nix store path of the archive.
    """
    with tempfile.TemporaryDirectory() as scratch:
        tarball_path = os.path.join(scratch, 'link.tar.gz')
        link_path = os.path.join(scratch, 'link')
        os.symlink(path, link_path)
        with tarfile.open(tarball_path, mode='x:gz') as archive:
            # tarfile records the symlink itself (dereference defaults off).
            archive.add(link_path, arcname='link')
        return copy_to_nix_store(v, tarball_path)
139
140
class AliasSearchPath(NamedTuple):
    """Channel that reuses another channel's tarball under its own name."""

    # Section name of the channel this one aliases.
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin; always return an empty AliasPin."""
        return AliasPin()
2fa9cbea
SW
147
148
0afcdb2a
SW
class SymlinkSearchPath(NamedTuple):
    """Channel whose contents are a symlink to a local path."""

    # Target of the 'link' symlink placed in the channel tarball.
    path: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Nothing to pin for a symlink; return an empty SymlinkPin."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return the store path of a tarball containing the symlink."""
        return symlink_archive(v, self.path)
0afcdb2a
SW
158
159
class GitSearchPath(NamedTuple):
    """Channel tracking a ref in a git repository."""

    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch ``git_ref`` and pin the commit it currently points to.

        With an existing *old_pin*, the new commit must descend from the
        previously pinned one (verify_git_ancestry fails otherwise).
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return a store path to a tarball of the pinned commit."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, pin.git_revision)
        return git_get_tarball(v, self, pin)
b3ea39b8 177
b3ea39b8 178
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """Channel published at ``channel_url`` and cross-checked against git."""

    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Pin the channel's current release after verifying it.

        Returns *old_pin* unchanged if the channel still reports the same
        git revision.  Otherwise the listed resources are fetched, the
        revision is checked to descend from the old pin's, and the tarball
        contents are compared with a git checkout.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        html, forwarded_url = fetch_channel(v, self)
        table, latest = parse_channel(v, html)
        if old_pin is not None and old_pin.git_revision == latest.git_revision:
            return old_pin
        fetch_resources(v, latest, forwarded_url, table)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, latest.git_revision)
        if old_pin is not None:
            verify_git_ancestry(
                v, self, old_pin.git_revision, latest.git_revision)
        check_channel_contents(v, self, table, latest)
        tarball = table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=latest.release_name,
            tarball_url=tarball.absolute_url,
            tarball_sha256=tarball.digest,
            git_revision=latest.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Fetch the pinned nixexprs tarball, verifying its SHA-256."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
211
212
0afcdb2a
SW
# Any channel kind that can appear in a config section.
SearchPath = Union[AliasSearchPath,
                   SymlinkSearchPath,
                   GitSearchPath,
                   ChannelSearchPath]
# Channel kinds with a git repository, accepted by git_get_tarball & co.
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 218
4a82be40 219
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at *a* and *b* by content.

    Regular files and symlinks (including symlinks to directories, which
    ``os.walk`` lists but does not descend into) are gathered from both
    trees, paths under a top-level ``.git/`` are dropped, and the union of
    relative paths is passed to ``filecmp.cmpfiles(..., shallow=False)``.

    Returns ``(match, mismatch, errors)`` from ``filecmp.cmpfiles``.
    Raises ``OSError`` if either tree cannot be walked.
    """

    def reraise(error: OSError) -> None:
        raise error

    def relative_join(prefix: str, name: str) -> str:
        # relpath() yields '.' for the root itself; avoid './name'.
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def tree_files(root: str) -> Set[str]:
        found: Set[str] = set()
        for dirpath, subdirs, filenames in os.walk(root, onerror=reraise):
            rel_dir = os.path.relpath(dirpath, start=root)
            for name in filenames:
                found.add(relative_join(rel_dir, name))
            for name in subdirs:
                if os.path.islink(relative_join(dirpath, name)):
                    found.add(relative_join(rel_dir, name))
        return {f for f in found if not f.startswith('.git/')}

    all_files = tree_files(a) | tree_files(b)
    return filecmp.cmpfiles(a, b, all_files, shallow=False)
246
247
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel page; return ``(html, url_after_redirects)``.

    The response must have HTTP status 200, and the final URL must differ
    from ``channel.channel_url`` (the 'Got forwarded' check).
    """
    v.status(f'Fetching channel from {channel.channel_url}')
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        page = response.read().decode()
        final_url = response.geturl()
        v.result(response.status == 200)
    v.check('Got forwarded', channel.channel_url != final_url)
    return page, final_url
2f96f32a
SW
257
258
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract release name, git revision and file table from a channel page.

    The page must parse as XML.  The release name is the third word of both
    the first <title> and the first <h1>, which must agree.  The git revision
    is the text of the first <tt>, whose preceding text node must be
    'Git commit '.  Each <tr> after the first is read as (name link, size,
    digest) cells.

    Returns the table keyed by file name and a GitPin for the release.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel_html)
    v.ok()

    v.status('Extracting release name:')
    release_name = doc.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    heading_name = doc.getElementsByTagName(
        'h1')[0].firstChild.nodeValue.split()[2]
    v.status(release_name)
    v.result(release_name == heading_name)

    v.status('Extracting git commit:')
    tt_node = doc.getElementsByTagName('tt')[0]
    revision = tt_node.firstChild.nodeValue
    v.status(revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(tt_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    for row in doc.getElementsByTagName('tr')[1:]:  # first row is the header
        cells = row.childNodes
        anchor = cells[0].firstChild
        name = anchor.firstChild.nodeValue
        url = anchor.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=release_name, git_revision=revision)
2f96f32a
SW
290
291
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 of the bytes *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
294
295
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded SHA-256 of *filename*, read in 4 KiB chunks."""
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while chunk := stream.read(4096):
            sha.update(chunk)
    return Digest16(sha.hexdigest())
303
304
b1e8c1b0
SW
@functools.lru_cache
def _experimental_flag_needed(v: Verification) -> bool:
    """Return whether ``nix --help`` mentions ``--experimental-features``.

    Used to decide whether nix-command must be enabled explicitly.  The
    answer is memoized by lru_cache, keyed on *v*.
    """
    v.status('Checking Nix version')
    help_run = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
    v.result(help_run.returncode == 0)
    help_text: bytes = help_run.stdout
    return b'--experimental-features' in help_text
1d48f551
SW
311
312
def _nix_command(v: Verification) -> List[str]:
    """Return the argv prefix for running ``nix``, enabling nix-command if needed."""
    if not _experimental_flag_needed(v):
        return ['nix']
    return ['nix', '--experimental-features', 'nix-command']
1d48f551
SW
316
317
73bec7e8
SW
def _nix_convert_digest(v: Verification, status: str, subcommand: str,
                        digest: str) -> str:
    """Run ``nix <subcommand> --type sha256 <digest>``; return its stripped output.

    *status* is reported on *v* first; a non-zero exit makes ``v.result``
    raise VerificationError.
    """
    v.status(status)
    process = subprocess.run(
        _nix_command(v) + [subcommand, '--type', 'sha256', digest],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return process.stdout.decode().strip()


def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base-32 SHA-256 digest to hex (base 16)."""
    return Digest16(_nix_convert_digest(
        v, 'Converting digest to base16', 'to-base16', digest32))


def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex (base 16) SHA-256 digest to Nix's base-32 form."""
    return Digest32(_nix_convert_digest(
        v, 'Converting digest to base32', 'to-base32', digest16))
340
341
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* into the Nix store and verify its SHA-256 is *digest*.

    The digest is checked three times:
    - nix-prefetch-url is given the expected hash;
    - the hash it prints is converted to hex and compared;
    - the stored file is re-hashed locally.

    Returns the store path printed by ``nix-prefetch-url --print-path``.
    """
    v.status(f'Fetching {url}')
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Output is "<base-32 hash>\n<store path>\n".
    reported_hash, store_path, trailing = process.stdout.decode().split('\n')
    assert trailing == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(reported_hash)) == digest)
    v.status(f"Verifying digest of {store_path}")
    v.result(digest_file(store_path) == digest)
    return store_path  # type: ignore # (for old mypy)
2f96f32a 358
73bec7e8 359
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Fetch the channel's git-revision and nixexprs.tar.xz files.

    Each entry's URL is resolved against *forwarded_url* and downloaded with
    digest verification.  The entry is updated in place with
    ``absolute_url`` and ``file``.  Finally, the first 999 characters of the
    git-revision file must equal ``pin.git_revision``.
    """
    for resource in ('git-revision', 'nixexprs.tar.xz'):
        entry = table[resource]
        entry.absolute_url = urllib.parse.urljoin(forwarded_url, entry.url)
        entry.file = fetch_with_nix_prefetch_url(
            v, entry.absolute_url, entry.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(table['git-revision'].file, encoding='utf-8') as revision_file:
        v.result(revision_file.read(999) == pin.git_revision)
2f96f32a 373
971d3659 374
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the cache file path that remembers the store tarball for *pin*.

    The file lives under ``$XDG_CACHE_HOME/pinch/git-tarball``.  Its name
    joins the SHA-256 of ``channel.git_repo``, the revision and the release
    name with '-'.
    """
    repo_hash = digest_string(channel.git_repo.encode())
    filename = f'{repo_hash}-{pin.git_revision}-{pin.release_name}'
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', filename)
971d3659
SW
380
381
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Require *old_revision* to be an ancestor of *new_revision*.

    Runs ``git merge-base --is-ancestor old new`` in the channel's git
    cache.  A non-zero exit makes ``v.result`` raise VerificationError.
    """
    cachedir = git_cache.git_cachedir(channel.git_repo)
    # The message names the direction actually checked: the previously
    # pinned revision must be an ancestor of the new one.
    v.status(
        f'Verifying previous rev {old_revision} is an ancestor of '
        f'new rev {new_revision}')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              old_revision,
                              new_revision])
    v.result(process.returncode == 0)
397
dc038df0 398
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    The tarball's ``<release_name>/`` directory is compared file by file
    with *git_contents*. The checks are:
    - at least one file must match;
    - no file may differ;
    - the only incomparable files allowed are the channel-generated ones
      listed below, and all of them must be present.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(
        os.path.join(channel_contents, pin.release_name), git_contents)
    v.ok()
    v.check(f'{len(match)} files match', len(match) > 0)
    v.check(f'{len(mismatch)} files differ', len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [ee for ee in expected_errors if ee in errors]
    for ee in benign_errors:
        errors.remove(ee)
    v.check(f'{len(errors)} unexpected incomparable files', len(errors) == 0)
    v.check(
        f'({len(benign_errors)} of {len(expected_errors)} expected incomparable files)',
        len(benign_errors) == len(expected_errors))
425
426
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the fetched nixexprs.tar.xz into *dest*."""
    archive = table['nixexprs.tar.xz'].file
    v.status(f"Extracting tarball {archive}")
    shutil.unpack_archive(archive, dest)
    v.ok()
434
435
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Write ``pin.git_revision``'s tree into *dest* via ``git archive | tar x``."""
    v.status('Checking out corresponding git revision')
    archive_cmd = ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                   'archive', pin.git_revision]
    extract_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    with subprocess.Popen(archive_cmd, stdout=subprocess.PIPE) as git:
        with subprocess.Popen(extract_cmd, stdin=git.stdout) as tar:
            # Drop our copy of the pipe so git gets SIGPIPE if tar exits early.
            if git.stdout:
                git.stdout.close()
            tar.wait()
        git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
925c801b
SW
451
452
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a Nix store path to a .tar.xz of ``pin.git_revision``.

    A cache file (tarball_cache_file) remembers the store path from a
    previous run and is reused while that path still exists.  Otherwise
    ``git archive --prefix=<release_name>/`` is piped through ``xz``, the
    result is added to the Nix store, and the cache file is rewritten.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file, encoding='utf-8') as f:
            cached_tarball = f.read(9999)
            if os.path.exists(cached_tarball):
                return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file descriptor, so
        # open it in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(f'Generating tarball for git revision {pin.git_revision}')
            with subprocess.Popen(
                    ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                     'archive', f'--prefix={pin.release_name}/', pin.git_revision],
                    stdout=subprocess.PIPE) as git:
                with subprocess.Popen(
                        ['xz'], stdin=git.stdout, stdout=output_file) as xz:
                    # Close the parent's copy of the pipe (as git_checkout
                    # does). Otherwise, if xz exits early, git blocks forever
                    # writing to a pipe nobody reads instead of getting SIGPIPE.
                    if git.stdout:
                        git.stdout.close()
                    xz.wait()
                git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'w', encoding='utf-8') as f:
            f.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
484
485
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Verify the extracted tarball's .git-revision and .version-suffix files.

    ``.git-revision`` must equal ``pin.git_revision``, and ``.version-suffix``
    must be a suffix of ``pin.release_name``.  At most 999 characters of
    each file are read.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(release_dir, '.git-revision'),
              encoding='utf-8') as revision_file:
        v.result(revision_file.read(999) == pin.git_revision)

    v.status(
        f'Verifying version-suffix is a suffix of release name {pin.release_name}:')
    with open(os.path.join(release_dir, '.version-suffix'),
              encoding='utf-8') as suffix_file:
        suffix = suffix_file.read(999)
    v.status(suffix)
    v.result(pin.release_name.endswith(suffix))
f9cd7bdc
SW
502
503
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Cross-check the channel tarball against the git revision it claims.

    The tarball and a git checkout are unpacked into two temporary
    directories.  The tarball metadata is verified and the two trees are
    compared.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:
        extract_tarball(v, table, channel_contents)
        check_channel_metadata(v, pin, channel_contents)
        git_checkout(v, channel, pin, git_contents)
        compare_tarball_and_git(v, pin, channel_contents, git_contents)
        v.status('Removing temporary directories')
    v.ok()
521
522
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Return ``<repo basename>-<commit time %ct>-<abbreviated hash>``.

    The hash is abbreviated with ``--abbrev=11``.  ``v.result`` raises
    VerificationError if ``git log`` fails or prints nothing.
    """
    v.status('Getting commit date')
    git_log = subprocess.run(
        ['git', '-C', git_cache.git_cachedir(channel.git_repo),
         'log', '-n1', '--format=%ct-%h', '--abbrev=11',
         '--no-show-signature', git_revision],
        stdout=subprocess.PIPE)
    v.result(git_log.returncode == 0 and git_log.stdout != b'')
    stamp = git_log.stdout.decode().strip()
    return f'{os.path.basename(channel.git_repo)}-{stamp}'
e3cae769
SW
540
541
567a6783
SW
# Key and value type variables for the dict helpers below.
K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (items where ``pred(key, value)`` holds, the rest).

    Both results keep *d*'s iteration order; *pred* is called once per item.
    """
    chosen: Dict[K, V] = {}
    rest: Dict[K, V] = {}
    for key, value in d.items():
        target = chosen if pred(key, value) else rest
        target[key] = value
    return chosen, rest


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries whose key is in *fields*, the rest)."""
    return partition_dict(lambda key, _: key in fields, d)
561
562
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path, and the pin if present, for one config section.

    The section's ``type`` selects the classes.  Keys that name pin fields go
    to the pin; all other keys except ``type`` go to the search path.  A pin
    is built only when some pin field is present, or when the pin type has
    no fields; otherwise the returned pin is None.
    Raises KeyError if ``type`` is missing or unknown.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    search_path_type, pin_type = mapping[conf['type']]
    _, settings = filter_dict(dict(conf.items()), {'type'})
    pin_settings, path_settings = filter_dict(settings, set(pin_type._fields))
    has_pin = bool(pin_settings) or pin_type._fields == ()
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin = pin_type(**pin_settings) if has_pin else None  # type: ignore
    return search_path_type(**path_settings), pin
f8f5b125
SW
578
579
e8bd4979
SW
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but require the section to be pinned already.

    :raises Exception: if *section* has no pin yet.
    """
    search_path, pin = read_config_section(conf)
    if pin is not None:
        return search_path, pin
    raise Exception(
        f'Cannot update unpinned channel "{section}" (Run "pin" before "update")')
587
588
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*, read as UTF-8."""
    config = configparser.ConfigParser()
    with open(filename, encoding='utf-8') as f:
        # Passing the filename lets configparser name it in parse errors.
        config.read_file(f, filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several channels files into one mapping.

    Sections are kept in file order, then in-file order.

    :raises Exception: if a section (channel) name appears in more than one
        file; the message names the duplicated channel.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception(f'Duplicate channel "{section}"')
            merged_config[section] = config[section]
    return merged_config
606
607
def pinCommand(args: argparse.Namespace) -> None:
    """Pin (or re-pin) channels in ``args.channels_file`` and rewrite it.

    When ``args.channels`` is non-empty, only those sections are processed.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if wanted and section not in wanted:
            continue
        search_path, old_pin = read_config_section(config[section])
        new_pin = search_path.pin(v, old_pin)
        config[section].update(new_pin._asdict())
    with open(args.channels_file, 'w', encoding='utf-8') as configfile:
        config.write(configfile)
2f96f32a
SW
621
622
def updateCommand(args: argparse.Namespace) -> None:
    """Fetch every pinned channel and install them all into ``args.profile``.

    Non-alias channels are fetched in section-name order.  Each one adds a
    ``-I pinch_tarball_for_<release_name>=<tarball>`` search path.  Alias
    channels reuse the release and tarball of the channel they point at,
    following chains of aliases.  The resulting ``nix-env --install
    --remove-all`` command is printed with --dry-run, otherwise it is run.

    :raises Exception: for unpinned channels, aliases of unknown channels,
        and alias cycles.
    """
    v = Verification()
    profile_manifest = os.path.join(args.profile, "manifest.nix")
    search_paths: List[str] = [
        "-I", "pinch_profile=" + args.profile,
        "-I", "pinch_profile_manifest=" + os.readlink(profile_manifest)
    ] if os.path.exists(profile_manifest) else []
    config = {
        section: read_pinned_config_section(section, conf)
        for section, conf in read_config_files(args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda _, entry: isinstance(entry[0], AliasSearchPath), config)

    # section name -> (release name, store path of its tarball)
    sources: Dict[str, Tuple[str, str]] = {}
    for section, (sp, pin) in sorted(nonalias.items()):
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        search_paths.extend(
            ["-I", f"pinch_tarball_for_{pin.release_name}={tarball}"])
        sources[section] = (pin.release_name, tarball)

    # Follow alias chains to a real channel, so the file order of aliases
    # does not matter and bad targets give a clear error.
    for section in alias:
        target = section
        visited: Set[str] = set()
        while target in alias:
            if target in visited:
                raise Exception(f'Channel alias cycle involving "{section}"')
            visited.add(target)
            alias_sp = alias[target][0]
            assert isinstance(alias_sp, AliasSearchPath)  # For mypy
            target = alias_sp.alias_of
        if target not in sources:
            raise Exception(
                f'Channel "{section}" is an alias of unknown channel "{target}"')
        sources[section] = sources[target]

    # Substitute the channel name directly instead of via %-formatting, so a
    # '%' in a release name or store path cannot break the expression.
    exprs = [
        f'f: f {{ name = "{release_name}"; channelName = "{name}"; '
        f'src = builtins.storePath "{tarball}"; }}'
        for name, (release_name, tarball) in sorted(sources.items())]

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--remove-all',
    ] + search_paths + ['--from-expression'] + exprs
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
736c25eb
SW
669
670
0e5e611d
SW
def main() -> None:
    """Command-line entry point: dispatch the ``pin`` and ``update`` subcommands."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    # Optional section names to pin; none means all sections.
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = f'/nix/var/nix/profiles/per-user/{getpass.getuser()}/channels'
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    args.func(args)
686
687
b5964ec3
SW
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()