]> git.scottworley.com Git - rfc1751/blobdiff - rfc1751.py
decode support
[rfc1751] / rfc1751.py
index 616ea02e60c9e793bfe83aba7c0e42622d725688..70f5b3a64e469c0faed2ad37509474989f991368 100644 (file)
@@ -18,24 +18,31 @@ import sys
 from rfc1751wordlist import WORD_LIST, WORD_LIST_SIZE
 
 
-# TODO: Decode
+WORD_LIST_INVERTED = {word: i for (i, word) in enumerate(WORD_LIST)}
+assert len(WORD_LIST_INVERTED) == WORD_LIST_SIZE
 
 
-def encode(x: int) -> str:
-    return WORD_LIST[x]
+def encode_actual(x: int) -> list[str]:
+    return [] if x <= 0 else encode_actual(
+        x // WORD_LIST_SIZE) + [WORD_LIST[x % WORD_LIST_SIZE]]
 
 
-def rfc1751_actual(x: int) -> list[str]:
-    return [] if x <= 0 else rfc1751_actual(
-        x // WORD_LIST_SIZE) + [encode(x % WORD_LIST_SIZE)]
+def encode(x: int) -> list[str]:
+    return [WORD_LIST[x]] if x == 0 else encode_actual(x)
 
 
-def rfc1751(x: int) -> list[str]:
-    return [WORD_LIST[x]] if x == 0 else rfc1751_actual(x)
+def decode(x: list[str]) -> int:
+    return WORD_LIST_SIZE * decode(x[:-1]) + \
+        WORD_LIST_INVERTED[x[-1]] if x else 0
 
 
 def main() -> None:
-    print(' '.join(rfc1751(int(sys.argv[1]))))
+    if sys.argv[1].isnumeric():
+        print(' '.join(encode(int(sys.argv[1]))))
+    elif len(sys.argv) == 2:
+        print(decode(sys.argv[1].split(' ')))
+    else:
+        print(decode(sys.argv[1:]))
 
 
 if __name__ == '__main__':