+from contextlib import contextmanager
+import json
+import os
+import subprocess
+import sys
+import xml.sax
+import xml.sax.handler
+
+
@contextmanager
def log(msg):
    """Show a transient progress message on stderr while a block runs.

    The message is printed without a trailing newline. On exit, whether
    normal or via an exception, it is overwritten with spaces and the cursor
    is returned to the start of the line. The next message, or any subprocess
    output, therefore starts on a clean line.

    Args:
        msg: text to display while the wrapped block executes.
    """
    print(msg, file=sys.stderr, end='', flush=True)
    try:
        yield
    finally:
        # A bare '\r' only moves the cursor. Blank the old text so a shorter
        # follow-up message does not leave a tail of this one behind.
        print('\r' + ' ' * len(msg) + '\r', file=sys.stderr, end='',
              flush=True)
+
+
class ParseNixStoreQueryGraphML(xml.sax.handler.ContentHandler):
    """SAX handler that builds a dependency map from GraphML output.

    Intended for the output of ``nix-store --query --graphml``.

    Attributes:
        roots: node ids that never appear as the target of an edge.
        non_roots: node ids seen as the target of at least one edge.
        deps: maps each edge source id to the list of its edge target ids,
            in document order.
    """

    def __init__(self):
        # Let ContentHandler initialise its own state (e.g. _locator).
        super().__init__()
        self.roots = set()
        self.non_roots = set()
        self.deps = {}

    def startElement(self, name, attrs):
        if name == "node":
            # Register declared nodes as root candidates. A graph without
            # any edges (a single isolated node) then still yields a root.
            node_id = attrs.get("id")
            if node_id is not None and node_id not in self.non_roots:
                self.roots.add(node_id)
        elif name == "edge":
            source = attrs.getValue("source")
            target = attrs.getValue("target")
            self.non_roots.add(target)
            self.deps.setdefault(source, []).append(target)
            # Anything that is pointed at cannot be a root, even if it was
            # provisionally recorded as one earlier in the document.
            self.roots.discard(target)
            if source not in self.non_roots:
                self.roots.add(source)
+
+
def getDeps(drv):
    """Parse the dependency graph of ``drv`` via ``nix-store --query --graphml``.

    Args:
        drv: store path (or name accepted by ``nix-store``) to query.

    Returns:
        The populated ``ParseNixStoreQueryGraphML`` handler, exposing
        ``roots``, ``non_roots`` and ``deps``.

    Raises:
        subprocess.CalledProcessError: if ``nix-store`` exits non-zero.
    """
    cmd = ["nix-store", "--query", "--graphml", drv]
    with log("Loading dependency tree..."):
        parser = ParseNixStoreQueryGraphML()
        # Popen as a context manager closes the pipe and reaps the child
        # even if parsing raises. It still streams the XML, not buffers it.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE) as process:
            xml.sax.parse(process.stdout, parser)
        # Check explicitly: a bare assert disappears under ``python -O``.
        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, cmd)
    return parser
+
+
def getDrvInfo(drv):
    """Return the parsed ``nix show-derivation`` JSON for a store derivation.

    Args:
        drv: store-path basename of the derivation; it is prefixed with
            ``/nix/store/`` before being passed to ``nix``.

    Returns:
        The decoded JSON document printed by ``nix show-derivation``.

    Raises:
        subprocess.CalledProcessError: if ``nix`` exits non-zero.
    """
    with log("Loading %s..." % drv):
        # run(check=True) waits for the child and raises on failure, so the
        # error is reported even under ``python -O`` (unlike an assert).
        result = subprocess.run(["nix",
                                 "--experimental-features",
                                 "nix-command",
                                 "show-derivation",
                                 "/nix/store/" + drv],
                                stdout=subprocess.PIPE,
                                check=True)
        return json.loads(result.stdout)
+
+
def allBuilt(drv):
    """Return True when every output path of ``drv`` exists on disk.

    Args:
        drv: store-path basename of a derivation.

    Returns:
        True if every output listed for every derivation in the
        ``nix show-derivation`` result has an existing ``path``. Returns
        False as soon as one is missing or has no known path.
    """
    # TODO: How to pin outputs one at a time?
    # Currently, we only pin when all outputs are available.
    # It would be better to pin the outputs we've got.
    for info in getDrvInfo(drv).values():
        for output in info['outputs'].values():
            path = output.get('path')
            # NOTE(review): the JSON may omit 'path' for outputs whose path
            # is not fixed in advance (confirm for the Nix version in use).
            # Treat such outputs as not built instead of raising KeyError.
            if path is None or not os.path.exists(path):
                return False
    return True
+
+
def isDrv(thing):
    """Tell whether ``thing`` names a Nix store derivation (ends in ``.drv``)."""
    suffix = ".drv"
    tail = thing[-len(suffix):]
    return tail == suffix
+
+
def removesuffix(s, suf):
    """Return ``s`` without a trailing ``suf``, or ``s`` unchanged if absent.

    This mirrors ``str.removesuffix`` (Python 3.9+), including returning
    ``s`` unchanged when ``suf`` is empty.
    """
    # Guard the empty suffix: s[:-0] is s[:0] == "", not s.
    if suf and s.endswith(suf):
        return s[:-len(suf)]
    return s
+
+
def pin(drv, gcroots_dir=None):
    """Realise a store path and register it as a GC root, unless already present.

    Args:
        drv: store-path basename (a derivation or a plain store path). It is
            prefixed with ``/nix/store/`` when passed to ``nix-store``.
        gcroots_dir: directory in which to create the root. Defaults to
            ``sys.argv[2]``, matching the script's command-line interface.

    Raises:
        subprocess.CalledProcessError: if ``nix-store --realise`` fails.
    """
    if gcroots_dir is None:
        gcroots_dir = sys.argv[2]
    # The root is named after the store path with any ".drv" suffix removed.
    outPath = os.path.join(gcroots_dir, removesuffix(drv, ".drv"))
    # Skip when something already exists at the root location (presumably
    # pinned by an earlier run).
    if os.path.exists(outPath):
        return
    subprocess.run(["nix-store",
                    "--realise",
                    "--add-root",
                    outPath,
                    "/nix/store/" + drv],
                   check=True)
+
+
def pinBuiltThings(thing, done, deps):
    """Pin the built frontier of the dependency graph reachable from ``thing``.

    Walks ``deps`` depth-first from ``thing``. A node is pinned, and not
    descended into, when it is not a derivation or when all its outputs
    already exist. Otherwise its dependencies are visited instead.

    Args:
        thing: starting node id (store-path basename).
        done: set of node ids already visited. It is updated in place, so
            nodes shared between calls are processed only once.
        deps: mapping from node id to the list of its dependency ids.
    """
    # Use an explicit stack rather than recursion, so a deep dependency
    # chain cannot hit Python's recursion limit. Dependencies are pushed in
    # reverse so they are popped, and thus visited, in their original order.
    # This matches the recursive depth-first order exactly.
    stack = [thing]
    while stack:
        node = stack.pop()
        if node in done:
            continue
        done.add(node)
        if not isDrv(node) or allBuilt(node):
            pin(node)
            continue
        stack.extend(reversed(deps.get(node, [])))
+
+
# Command line: DRV GCROOTS_DIR. DRV is queried with nix-store --graphml,
# and roots for already-built nodes are created in GCROOTS_DIR.
if len(sys.argv) < 3:
    sys.exit("usage: %s DRV GCROOTS_DIR" % sys.argv[0])

dep_graph = getDeps(sys.argv[1])
# Share one visited set across roots. Whether a node is pinned or descended
# into depends only on that node, not on which root reached it, so
# revisiting it would only repeat the nix queries.
done = set()
# Sort for a reproducible processing order (set iteration order varies).
for root in sorted(dep_graph.roots):
    pinBuiltThings(root, done, dep_graph.deps)