diff options
Diffstat (limited to 'src/libutil/args.cc')
-rw-r--r-- | src/libutil/args.cc | 55 |
1 files changed, 53 insertions, 2 deletions
diff --git a/src/libutil/args.cc b/src/libutil/args.cc index f829415d1..ceaabcfca 100644 --- a/src/libutil/args.cc +++ b/src/libutil/args.cc @@ -13,6 +13,19 @@ void Args::addFlag(Flag && flag_) if (flag->shortName) shortFlags[flag->shortName] = flag; } +std::shared_ptr<std::set<std::string>> completions; + +std::string completionMarker = "___COMPLETE___"; + +std::optional<std::string> needsCompletion(std::string_view s) +{ + if (!completions) return {}; + auto i = s.find(completionMarker); + if (i != std::string::npos) + return std::string(s.begin(), i); + return {}; +} + void Args::parseCmdline(const Strings & _cmdline) { Strings pendingArgs; @@ -20,6 +33,13 @@ void Args::parseCmdline(const Strings & _cmdline) Strings cmdline(_cmdline); + if (auto s = getEnv("NIX_GET_COMPLETIONS")) { + size_t n = std::stoi(*s); + assert(n > 0 && n <= cmdline.size()); + *std::next(cmdline.begin(), n - 1) += completionMarker; + completions = std::make_shared<decltype(completions)::element_type>(); + } + for (auto pos = cmdline.begin(); pos != cmdline.end(); ) { auto arg = *pos; @@ -99,18 +119,32 @@ bool Args::processFlag(Strings::iterator & pos, Strings::iterator end) auto process = [&](const std::string & name, const Flag & flag) -> bool { ++pos; std::vector<std::string> args; + bool anyNeedsCompletion = false; for (size_t n = 0 ; n < flag.handler.arity; ++n) { if (pos == end) { if (flag.handler.arity == ArityAny) break; - throw UsageError("flag '%s' requires %d argument(s)", name, flag.handler.arity); + if (anyNeedsCompletion) + args.push_back(""); + else + throw UsageError("flag '%s' requires %d argument(s)", name, flag.handler.arity); + } else { + if (needsCompletion(*pos)) + anyNeedsCompletion = true; + args.push_back(*pos++); } - args.push_back(*pos++); } flag.handler.fun(std::move(args)); return true; }; if (string(*pos, 0, 2) == "--") { + if (auto prefix = needsCompletion(*pos)) { + for (auto & [name, flag] : longFlags) { + if (!hiddenCategories.count(flag->category) + && hasPrefix(name, std::string(*prefix, 2))) + completions->insert("--" + name); + } + } auto i = longFlags.find(string(*pos, 2)); if (i == longFlags.end()) return false; return process("--" + i->first, *i->second); @@ -123,6 +157,14 @@ bool Args::processFlag(Strings::iterator & pos, Strings::iterator end) return process(std::string("-") + c, *i->second); } + if (auto prefix = needsCompletion(*pos)) { + if (prefix == "-") { + completions->insert("--"); + for (auto & [flag, _] : shortFlags) + completions->insert(std::string("-") + flag); + } + } + return false; } @@ -161,6 +203,10 @@ Args::Flag Args::Flag::mkHashTypeFlag(std::string && longName, HashType * ht) .description = "hash algorithm ('md5', 'sha1', 'sha256', or 'sha512')", .labels = {"hash-algo"}, .handler = {[ht](std::string s) { + if (auto prefix = needsCompletion(s)) + for (auto & type : hashTypes) + if (hasPrefix(type, *prefix)) + completions->insert(type); *ht = parseHashType(s); if (*ht == htUnknown) throw UsageError("unknown hash type '%1%'", s); @@ -217,6 +263,11 @@ MultiCommand::MultiCommand(const Commands & commands) { expectedArgs.push_back(ExpectedArg{"command", 1, true, [=](std::vector<std::string> ss) { assert(!command); + if (auto prefix = needsCompletion(ss[0])) { + for (auto & [name, command] : commands) + if (hasPrefix(name, *prefix)) + completions->insert(name); + } auto i = commands.find(ss[0]); if (i == commands.end()) throw UsageError("'%s' is not a recognised command", ss[0]); |