From ab2fc45dbde41e905057b475701fbf0346e087e3 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Fri, 10 Mar 2023 10:45:31 +0100 Subject: [PATCH 01/22] engine: Remove --distinct option from check command --- engine/src/main.cpp | 1 - engine/src/main_check.hpp | 44 +-------------------------------------- 2 files changed, 1 insertion(+), 44 deletions(-) diff --git a/engine/src/main.cpp b/engine/src/main.cpp index 04fd65084..87e78fc68 100644 --- a/engine/src/main.cpp +++ b/engine/src/main.cpp @@ -69,7 +69,6 @@ int main(int argc, char** argv) { engine::CheckOptions check_options; std::vector check_files; auto check = app.add_subcommand("check", "Validate stack file configurations."); - check->add_flag("-d,--distinct", check_options.distinct, "Validate each file distinctly"); check->add_flag("-s,--summarize", check_options.summarize, "Summarize results"); check->add_flag("-j,--json", check_options.output_json, "Output results as JSON data"); check->add_option("-J,--json-indent", check_options.json_indent, "JSON indentation level"); diff --git a/engine/src/main_check.hpp b/engine/src/main_check.hpp index 3dd4ca477..57693f3ba 100644 --- a/engine/src/main_check.hpp +++ b/engine/src/main_check.hpp @@ -38,7 +38,6 @@ struct CheckOptions { std::string delimiter = ","; // Flags: - bool distinct = false; bool summarize = false; bool output_json = false; int json_indent = 2; @@ -129,49 +128,8 @@ inline int check_merged(const CheckOptions& opt, const std::vector& return ok ? EXIT_SUCCESS : EXIT_FAILURE; } -inline int check_distinct(const CheckOptions& opt, const std::vector& filepaths) { - int exit_code = EXIT_SUCCESS; - auto check_each = [&](std::function func) { - for (const auto& x : filepaths) { - bool ok = true; - func(x, &ok); - if (!ok) { - exit_code = EXIT_FAILURE; - } - } - }; - - if (opt.output_json) { - // Output for each file a summary - cloe::Json output; - check_each([&](const auto& f, bool* ok) { - output[f] = check_json(opt, std::vector{f}, ok); - }); - opt.output << output.dump(opt.json_indent) << std::endl; - } else if (opt.summarize) { - check_each([&](const auto& f, bool* ok) { - opt.output << f << ": " << check_summary(opt, std::vector{f}, ok) << std::endl; - }); - } else { - check_each([&](const auto& f, bool* ok) { - try { - check_stack(opt.stack_options, std::vector{f}, ok); - } catch (cloe::ConcludedError&) { - } catch (std::exception& e) { - opt.output << f << ": " << e.what() << std::endl; - } - }); - } - - return exit_code; -} - inline int check(const CheckOptions& opt, const std::vector& filepaths) { - if (opt.distinct) { - return check_distinct(opt, filepaths); - } else { - return check_merged(opt, filepaths); - } + return check_merged(opt, filepaths); } } // namespace engine From 9a7a556194f1dbe7412ec4fa4f6c9472cbc9ffaa Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 Mar 2023 10:24:13 +0100 Subject: [PATCH 02/22] engine: Set stack options in own scope --- engine/src/main.cpp | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/engine/src/main.cpp b/engine/src/main.cpp index 87e78fc68..fdcc024f8 100644 --- a/engine/src/main.cpp +++ b/engine/src/main.cpp @@ -149,19 +149,22 @@ int main(int argc, char** argv) { } // Setup stack, applying strict/secure mode if necessary, and provide launch command. - if (stack_options.secure_mode) { - stack_options.strict_mode = true; - stack_options.no_hooks = true; - stack_options.interpolate_vars = false; - } - if (stack_options.strict_mode) { - stack_options.no_system_plugins = true; - stack_options.no_system_confs = true; - run_options.require_success = true; + { + if (stack_options.secure_mode) { + stack_options.strict_mode = true; + stack_options.no_hooks = true; + stack_options.interpolate_vars = false; + } + if (stack_options.strict_mode) { + stack_options.no_system_plugins = true; + stack_options.no_system_confs = true; + run_options.require_success = true; + } + stack_options.environment->prefer_external(false); + stack_options.environment->allow_undefined(stack_options.interpolate_undefined); + stack_options.environment->insert(CLOE_SIMULATION_UUID_VAR, "${" CLOE_SIMULATION_UUID_VAR "}"); } - stack_options.environment->prefer_external(false); - stack_options.environment->allow_undefined(stack_options.interpolate_undefined); - stack_options.environment->insert(CLOE_SIMULATION_UUID_VAR, "${" CLOE_SIMULATION_UUID_VAR "}"); + auto with_stack_options = [&](auto& opt) -> decltype(opt) { opt.stack_options = stack_options; return opt; From 87d82bcb652364b0d068b1b50a597a15f7c7a8a4 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Wed, 10 May 2023 14:33:12 +0200 Subject: [PATCH 03/22] engine: Extract SimulationProgress into simulation_progress.hpp --- engine/CMakeLists.txt | 2 + engine/src/simulation_context.hpp | 137 +---------------------- engine/src/simulation_progress.hpp | 168 +++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 136 deletions(-) create mode 100644 engine/src/simulation_progress.hpp diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index abf653f82..0ffff85a6 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -85,6 +85,8 @@ add_executable(cloe-engine src/coordinator.cpp src/simulation.cpp src/simulation_context.cpp + src/simulation_context.hpp + src/simulation_progress.hpp src/utility/command.cpp ) set_target_properties(cloe-engine PROPERTIES diff --git a/engine/src/simulation_context.hpp b/engine/src/simulation_context.hpp index ebe97765a..d3e81b125 100644 --- a/engine/src/simulation_context.hpp +++ b/engine/src/simulation_context.hpp @@ -45,8 +45,8 @@ #include "registrar.hpp" // for Registrar #include "server.hpp" // for Server #include "stack.hpp" // for Stack +#include "simulation_progress.hpp" // for SimulationProgress #include "utility/command.hpp" // for CommandExecuter -#include "utility/progress.hpp" // for Progress #include "utility/time_event.hpp" // for TimeCallback namespace engine { @@ -120,141 +120,6 @@ class SimulationSync : public cloe::Sync { cloe::Duration step_width_{20'000'000}; // should be 20ms }; -/** - * SimulationProgress represents the progress of the simulation, split into - * initialization and execution phases. - */ -class SimulationProgress { - using TimePoint = std::chrono::steady_clock::time_point; - - public: - std::string stage{""}; - std::string message{"initializing engine"}; - - Progress initialization; - size_t initialization_n; - size_t initialization_k; - - Progress execution; - cloe::Duration execution_eta{0}; - - // Reporting: - double report_granularity_p{0.1}; - cloe::Duration report_granularity_d{10'000'000'000}; - double execution_report_p; - TimePoint execution_report_t; - - public: - void init_begin(size_t n) { - message = "initializing"; - initialization.begin(); - initialization_n = n; - initialization_k = 0; - } - - void init(const std::string& what) { - stage = what; - message = "initializing " + what; - initialization_k++; - double p = static_cast(initialization_k) / static_cast(initialization_n); - initialization.update(p); - } - - void init_end() { - initialization_k++; - assert(initialization_k == initialization_n); - initialization.end(); - stage = ""; - message = "initialization done"; - } - - bool is_init_ended() const { return initialization.is_ended(); } - - cloe::Duration elapsed() const { - if (is_init_ended()) { - return initialization.elapsed() + execution.elapsed(); - } else { - return initialization.elapsed(); - } - } - - void exec_begin() { - stage = "simulation"; - message = "executing simulation"; - execution_report_p = 0; - execution_report_t = std::chrono::steady_clock::now(); - execution.begin(); - } - - void exec_update(double p) { execution.update_safe(p); } - - void exec_update(cloe::Duration now) { - if (execution_eta != cloe::Duration(0)) { - double now_d = static_cast(now.count()); - double eta_d = static_cast(execution_eta.count()); - exec_update(now_d / eta_d); - } - } - - void exec_end() { - stage = ""; - message = "simulation done"; - execution.end(); - } - - bool is_exec_ended() const { return execution.is_ended(); } - - /** - * Return true and store the current progress percentage and time if the - * current percentage is granularity_p ahead or at least granularity_d has - * elapsed since the last report. - */ - bool exec_report() { - // We should not report 100% more than once. - if (execution_report_p == 1.0) { - return false; - } - - // If there is no execution ETA, also don't report. - if (execution_eta == cloe::Duration(0)) { - return false; - } - - // Certain minimum percentage has passed. - auto now = std::chrono::steady_clock::now(); - if (execution.is_ended()) { - // We should report 100% at least once. - execution_report_p = 1.0; - execution_report_t = now; - return true; - } else if (execution.percent() - execution_report_p > report_granularity_p) { - // We should report at least every report_granularity_p (percent). - execution_report_p = execution.percent(); - execution_report_t = now; - return true; - } else if (cast_duration(now - execution_report_t) > report_granularity_d) { - // We should report at least every report_granularity_d (duration). - execution_report_p = execution.percent(); - execution_report_t = now; - return true; - } else { - return false; - } - } - - friend void to_json(cloe::Json& j, const SimulationProgress& p) { - j = cloe::Json{ - {"message", p.message}, - {"initialization", p.initialization}, - }; - if (p.execution_eta > cloe::Duration(0)) { - j["execution"] = p.execution; - } else { - j["execution"] = nullptr; - } - } -}; - struct SimulationStatistics { cloe::utility::Accumulator engine_time_ms; cloe::utility::Accumulator cycle_time_ms; diff --git a/engine/src/simulation_progress.hpp b/engine/src/simulation_progress.hpp new file mode 100644 index 000000000..2568afedf --- /dev/null +++ b/engine/src/simulation_progress.hpp @@ -0,0 +1,168 @@ +/* + * Copyright 2020 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file simulation_progress.hpp + */ + +#pragma once + +#include +#include + +#include + +#include "utility/progress.hpp" // for Progress + +namespace engine { + +/** + * SimulationProgress represents the progress of the simulation, split into + * initialization and execution phases. + */ +class SimulationProgress { + using TimePoint = std::chrono::steady_clock::time_point; + + public: + std::string stage{""}; + std::string message{"initializing engine"}; + + Progress initialization; + size_t initialization_n; + size_t initialization_k; + + Progress execution; + cloe::Duration execution_eta{0}; + + // Reporting: + double report_granularity_p{0.1}; + cloe::Duration report_granularity_d{10'000'000'000}; + double execution_report_p; + TimePoint execution_report_t; + + public: + void init_begin(size_t n) { + message = "initializing"; + initialization.begin(); + initialization_n = n; + initialization_k = 0; + } + + void init(const std::string& what) { + stage = what; + message = "initializing " + what; + initialization_k++; + double p = static_cast(initialization_k) / static_cast(initialization_n); + initialization.update(p); + } + + void init_end() { + initialization_k++; + assert(initialization_k == initialization_n); + initialization.end(); + stage = ""; + message = "initialization done"; + } + + bool is_init_ended() const { return initialization.is_ended(); } + + cloe::Duration elapsed() const { + if (is_init_ended()) { + return initialization.elapsed() + execution.elapsed(); + } else { + return initialization.elapsed(); + } + } + + void exec_begin() { + stage = "simulation"; + message = "executing simulation"; + execution_report_p = 0; + execution_report_t = std::chrono::steady_clock::now(); + execution.begin(); + } + + void exec_update(double p) { execution.update_safe(p); } + + void exec_update(cloe::Duration now) { + if (execution_eta != cloe::Duration(0)) { + double now_d = static_cast(now.count()); + double eta_d = static_cast(execution_eta.count()); + exec_update(now_d / eta_d); + } + } + + void exec_end() { + stage = ""; + message = "simulation done"; + execution.end(); + } + + bool is_exec_ended() const { return execution.is_ended(); } + + /** + * Return true and store the current progress percentage and time if the + * current percentage is granularity_p ahead or at least granularity_d has + * elapsed since the last report. + */ + bool exec_report() { + // We should not report 100% more than once. + if (execution_report_p == 1.0) { + return false; + } + + // If there is no execution ETA, also don't report. + if (execution_eta == cloe::Duration(0)) { + return false; + } + + // Certain minimum percentage has passed. + auto now = std::chrono::steady_clock::now(); + if (execution.is_ended()) { + // We should report 100% at least once. + execution_report_p = 1.0; + execution_report_t = now; + return true; + } else if (execution.percent() - execution_report_p > report_granularity_p) { + // We should report at least every report_granularity_p (percent). + execution_report_p = execution.percent(); + execution_report_t = now; + return true; + } else if (cast_duration(now - execution_report_t) > report_granularity_d) { + // We should report at least every report_granularity_d (duration). + execution_report_p = execution.percent(); + execution_report_t = now; + return true; + } else { + return false; + } + } + + friend void to_json(cloe::Json& j, const SimulationProgress& p) { + j = cloe::Json{ + {"message", p.message}, + {"initialization", p.initialization}, + }; + if (p.execution_eta > cloe::Duration(0)) { + j["execution"] = p.execution; + } else { + j["execution"] = nullptr; + } + } +}; + +} // namespace engine From 93b4d2425f1e84a46133832af3f049f13e321443 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 23:49:52 +0200 Subject: [PATCH 04/22] runtime: Add CallbackResult return value to all action operator() BREAKING CHANGE: If you have implemented any Actions yourself, you will need to make sure the `operator()(const Sync& ...)` returns `cloe::CallbackResult`. You can use `CallbackResult::Ok` to achieve the same behavior. --- engine/src/coordinator.cpp | 58 ++++--- engine/src/coordinator.hpp | 30 ++-- engine/src/utility/command.cpp | 2 +- engine/src/utility/command.hpp | 2 +- optional/vtd/src/scp_action.hpp | 3 +- plugins/basic/src/hmi_contact.hpp | 2 +- runtime/include/cloe/conf/action.hpp | 5 +- runtime/include/cloe/registrar.hpp | 5 +- runtime/include/cloe/trigger.hpp | 14 +- .../include/cloe/trigger/example_actions.hpp | 11 +- runtime/include/cloe/trigger/set_action.hpp | 161 +++++++++--------- runtime/src/cloe/trigger.cpp | 4 +- runtime/src/cloe/trigger/example_actions.cpp | 25 +-- 13 files changed, 182 insertions(+), 140 deletions(-) diff --git a/engine/src/coordinator.cpp b/engine/src/coordinator.cpp index 844261f23..10bc74bad 100644 --- a/engine/src/coordinator.cpp +++ b/engine/src/coordinator.cpp @@ -62,6 +62,8 @@ class TriggerRegistrar : public cloe::TriggerRegistrar { return manager_.make_trigger(source_, c); } + // TODO: Should these queue_trigger becomes inserts? Because if they are coming from + // C++ then they should be from a single thread. void insert_trigger(const Conf& c) override { manager_.queue_trigger(source_, c); } void insert_trigger(TriggerPtr&& t) override { manager_.queue_trigger(std::move(t)); } @@ -168,44 +170,54 @@ void Coordinator::register_event(const std::string& key, EventFactoryPtr&& ef, std::bind(&Coordinator::execute_trigger, this, std::placeholders::_1, std::placeholders::_2)); } -void Coordinator::execute_trigger(TriggerPtr&& t, const Sync& sync) { +cloe::CallbackResult Coordinator::execute_trigger(TriggerPtr&& t, const Sync& sync) { logger()->debug("Execute trigger {}", inline_json(*t)); - (t->action())(sync, *executer_registrar_); + auto result = (t->action())(sync, *executer_registrar_); if (!t->is_conceal()) { - history_.emplace_back(HistoryTrigger{sync.time(), std::move(t)}); + history_.emplace_back(sync.time(), std::move(t)); } + return result; } -Duration Coordinator::process(const Sync& sync) { +size_t Coordinator::process_pending_web_triggers(const Sync& sync) { // The only thing we need to do here is distribute the triggers from the // input queue into their respective storage locations. We are responsible // for thread safety here! - auto now = sync.time(); + size_t count = 0; std::unique_lock guard(input_mutex_); while (!input_queue_.empty()) { - auto tp = std::move(input_queue_.front()); + store_trigger(std::move(input_queue_.front()), sync); input_queue_.pop_front(); - tp->set_since(now); + count++; + } + return count; +} - // Decide where to put the trigger - auto key = tp->event().name(); - if (storage_.count(key) == 0) { - // This is a programming error, since we should not be able to come this - // far at all. - throw std::logic_error("cannot insert trigger with unregistered event"); - } - try { - logger()->debug("Insert trigger {}", inline_json(*tp)); - storage_[key]->emplace(std::move(tp), sync); - } catch (TriggerError& e) { - logger()->error("Error inserting trigger: {}", e.what()); - if (!allow_errors_) { - throw; - } +void Coordinator::store_trigger(TriggerPtr&& tp, const Sync& sync) { + tp->set_since(sync.time()); + + // Decide where to put the trigger + auto key = tp->event().name(); + if (storage_.count(key) == 0) { + // This is a programming error, since we should not be able to come this + // far at all. + throw std::logic_error("cannot insert trigger with unregistered event"); + } + try { + logger()->debug("Insert trigger {}", inline_json(*tp)); + storage_[key]->emplace(std::move(tp), sync); + } catch (TriggerError& e) { + logger()->error("Error inserting trigger: {}", e.what()); + if (!allow_errors_) { + throw; } } +} - return now; +Duration Coordinator::process(const Sync& sync) { + auto now = sync.time(); + process_pending_web_triggers(sync); + return sync.time(); } namespace { diff --git a/engine/src/coordinator.hpp b/engine/src/coordinator.hpp index 00a538d70..a52452fd2 100644 --- a/engine/src/coordinator.hpp +++ b/engine/src/coordinator.hpp @@ -22,25 +22,21 @@ #pragma once -#include // for list<> -#include // for map<> -#include // for unique_ptr<>, shared_ptr<> -#include // for mutex -#include // for queue<> -#include // for string -#include // for vector<> +#include // for list<> +#include // for map<> +#include // for unique_ptr<>, shared_ptr<> +#include // for mutex +#include // for queue<> +#include // for string +#include // for vector<> -#include // for Trigger, Action, Event, ... - -// Forward declaration: -namespace cloe { -class Registrar; -} +#include // for Registrar +#include // for Trigger, Action, Event, ... namespace engine { // Forward declarations: -class TriggerRegistrar; // from trigger_manager.cpp +class TriggerRegistrar; // from coordinator.cpp /** * TriggerUnknownAction is thrown when an Action cannot be created because the @@ -120,13 +116,15 @@ class Coordinator { */ cloe::Duration process(const cloe::Sync&); + size_t process_pending_web_triggers(const cloe::Sync& sync); protected: cloe::ActionPtr make_action(const cloe::Conf& c) const; cloe::EventPtr make_event(const cloe::Conf& c) const; cloe::TriggerPtr make_trigger(cloe::Source s, const cloe::Conf& c) const; void queue_trigger(cloe::Source s, const cloe::Conf& c) { queue_trigger(make_trigger(s, c)); } - void queue_trigger(cloe::TriggerPtr&& t); - void execute_trigger(cloe::TriggerPtr&& t, const cloe::Sync& s); + void queue_trigger(cloe::TriggerPtr&& tp); + void store_trigger(cloe::TriggerPtr&& tp, const cloe::Sync& sync); + cloe::CallbackResult execute_trigger(cloe::TriggerPtr&& tp, const cloe::Sync& sync); // for access to protected methods friend TriggerRegistrar; diff --git a/engine/src/utility/command.cpp b/engine/src/utility/command.cpp index 6de1a7586..89596e549 100644 --- a/engine/src/utility/command.cpp +++ b/engine/src/utility/command.cpp @@ -128,7 +128,7 @@ void CommandExecuter::wait_all() { namespace actions { -void Command::operator()(const cloe::Sync&, cloe::TriggerRegistrar&) { executer_->run(command_); } +cloe::CallbackResult Command::operator()(const cloe::Sync&, cloe::TriggerRegistrar&) { executer_->run(command_); return cloe::CallbackResult::Ok; } void Command::to_json(cloe::Json& j) const { command_.to_json(j); } diff --git a/engine/src/utility/command.hpp b/engine/src/utility/command.hpp index eb4b4c373..41789aca4 100644 --- a/engine/src/utility/command.hpp +++ b/engine/src/utility/command.hpp @@ -86,7 +86,7 @@ class Command : public cloe::Action { return std::make_unique(name(), command_, executer_); } - void operator()(const cloe::Sync&, cloe::TriggerRegistrar&) override; + cloe::CallbackResult operator()(const cloe::Sync&, cloe::TriggerRegistrar&) override; protected: void to_json(cloe::Json& j) const override; diff --git a/optional/vtd/src/scp_action.hpp b/optional/vtd/src/scp_action.hpp index 930a76320..1184f63e2 100644 --- a/optional/vtd/src/scp_action.hpp +++ b/optional/vtd/src/scp_action.hpp @@ -35,9 +35,10 @@ class ScpAction : public cloe::Action { ScpAction(const std::string& name, std::shared_ptr scp_client, const std::string& msg) : cloe::Action(name), client_(scp_client), xml_(msg) {} cloe::ActionPtr clone() const override { return std::make_unique(name(), client_, xml_); } - void operator()(const cloe::Sync&, cloe::TriggerRegistrar&) override { + cloe::CallbackResult operator()(const cloe::Sync&, cloe::TriggerRegistrar&) override { logger()->info("Sending SCP message: {}", xml_); client_->send(xml_); + return cloe::CallbackResult::Ok; } bool is_significant() const override { return false; } diff --git a/plugins/basic/src/hmi_contact.hpp b/plugins/basic/src/hmi_contact.hpp index 5a0b166a5..ef3f132ec 100644 --- a/plugins/basic/src/hmi_contact.hpp +++ b/plugins/basic/src/hmi_contact.hpp @@ -207,7 +207,7 @@ class UseContact : public Action { UseContact(const std::string& name, ContactMap* m, const Conf& data) : Action(name), hmi_(m), data_(data) {} ActionPtr clone() const override { return std::make_unique>(name(), hmi_, data_); } - void operator()(const Sync&, TriggerRegistrar&) override { from_json(*data_, *hmi_); } + CallbackResult operator()(const Sync&, TriggerRegistrar&) override { from_json(*data_, *hmi_); return CallbackResult::Ok; } protected: void to_json(Json& j) const override { j = *data_; } diff --git a/runtime/include/cloe/conf/action.hpp b/runtime/include/cloe/conf/action.hpp index 81c4eed3b..ee8966fc4 100644 --- a/runtime/include/cloe/conf/action.hpp +++ b/runtime/include/cloe/conf/action.hpp @@ -39,7 +39,10 @@ class Configure : public Action { conf_.erase("name"); } ActionPtr clone() const override { return std::make_unique(name(), ptr_, conf_); } - void operator()(const Sync&, TriggerRegistrar&) override { ptr_->from_conf(conf_); } + CallbackResult operator()(const Sync&, TriggerRegistrar&) override { + ptr_->from_conf(conf_); + return CallbackResult::Ok; + } protected: void to_json(Json& j) const override { j = *conf_; } diff --git a/runtime/include/cloe/registrar.hpp b/runtime/include/cloe/registrar.hpp index 2cf4f7bc1..f96065318 100644 --- a/runtime/include/cloe/registrar.hpp +++ b/runtime/include/cloe/registrar.hpp @@ -55,7 +55,10 @@ class DirectCallback : public Callback { auto& condition = dynamic_cast((*it)->event()); if (condition(sync, args...)) { if ((*it)->is_sticky()) { - this->execute((*it)->clone(), sync); + auto result = this->execute((*it)->clone(), sync); + if (result == CallbackResult::Unpin) { + it = triggers_.erase(it); + } } else { // Remove from trigger list and advance. this->execute(std::move(*it), sync); diff --git a/runtime/include/cloe/trigger.hpp b/runtime/include/cloe/trigger.hpp index 17ad22f13..af832fc72 100644 --- a/runtime/include/cloe/trigger.hpp +++ b/runtime/include/cloe/trigger.hpp @@ -507,10 +507,18 @@ class Event : public Entity { using EventFactory = TriggerFactory; using EventFactoryPtr = std::unique_ptr; +enum class CallbackResult { + /// Default, use standard behavior. + Ok, + + /// Remove from callback if it was sticky. + Unpin, +}; + /** * Interface the trigger manager must provide for executing triggers. */ -using CallbackExecuter = std::function; +using CallbackExecuter = std::function; /** * Callback provides the interface with which the global trigger manager, @@ -555,7 +563,7 @@ class Callback { * Execute a trigger in the given sync context by passing it to the * executer. */ - void execute(TriggerPtr&& t, const Sync& s); + CallbackResult execute(TriggerPtr&& t, const Sync& s); private: CallbackExecuter executer_; @@ -619,7 +627,7 @@ class Action : public Entity { /** * Execute the action. */ - virtual void operator()(const Sync&, TriggerRegistrar&) = 0; + virtual CallbackResult operator()(const Sync&, TriggerRegistrar&) = 0; /** * Return whether this action is a significant action. diff --git a/runtime/include/cloe/trigger/example_actions.hpp b/runtime/include/cloe/trigger/example_actions.hpp index 3161c26e4..cacded8d6 100644 --- a/runtime/include/cloe/trigger/example_actions.hpp +++ b/runtime/include/cloe/trigger/example_actions.hpp @@ -41,7 +41,10 @@ class Log : public Action { Log(const std::string& name, LogLevel level, const std::string& msg) : Action(name), level_(level), msg_(msg) {} ActionPtr clone() const override { return std::make_unique(name(), level_, msg_); } - void operator()(const Sync&, TriggerRegistrar&) override { logger()->log(level_, msg_.c_str()); } + CallbackResult operator()(const Sync&, TriggerRegistrar&) override { + logger()->log(level_, msg_.c_str()); + return CallbackResult::Ok; + } bool is_significant() const override { return false; } protected: @@ -71,7 +74,7 @@ class Bundle : public Action { public: Bundle(const std::string& name, std::vector&& actions); ActionPtr clone() const override; - void operator()(const Sync& s, TriggerRegistrar& r) override; + CallbackResult operator()(const Sync& s, TriggerRegistrar& r) override; bool is_significant() const override; protected: @@ -103,7 +106,7 @@ class Insert : public Action { public: Insert(const std::string& name, const Conf& triggers) : Action(name), triggers_(triggers) {} ActionPtr clone() const override { return std::make_unique(name(), triggers_); } - void operator()(const Sync& s, TriggerRegistrar& r) override; + CallbackResult operator()(const Sync& s, TriggerRegistrar& r) override; protected: void to_json(Json& j) const override; @@ -138,7 +141,7 @@ class PushRelease : public Action { return std::make_unique(name(), duration_, push_->clone(), release_->clone(), repr_); } - void operator()(const Sync&, TriggerRegistrar&) override; + CallbackResult operator()(const Sync&, TriggerRegistrar&) override; protected: void to_json(Json& j) const override { j = repr_; } diff --git a/runtime/include/cloe/trigger/set_action.hpp b/runtime/include/cloe/trigger/set_action.hpp index c65cd699f..4f07072e9 100644 --- a/runtime/include/cloe/trigger/set_action.hpp +++ b/runtime/include/cloe/trigger/set_action.hpp @@ -68,7 +68,10 @@ class SetVariableAction : public Action { ActionPtr clone() const override { return std::make_unique(name(), data_name_, data_ptr_, value_); } - void operator()(const Sync&, TriggerRegistrar&) override { *data_ptr_ = value_; } + CallbackResult operator()(const Sync&, TriggerRegistrar&) override { + *data_ptr_ = value_; + return CallbackResult::Ok; + } bool is_significant() const override { return false; } void to_json(Json& j) const override { j = Json{ @@ -138,34 +141,37 @@ class SetVariableActionFactory : public ActionFactory { * * This action can be registered with the `register_action` helper function. */ -#define DEFINE_SET_STATE_ACTION(xName, xname, xdescription, xState, xOperatorBlock) \ - class xName : public ::cloe::Action { \ - public: \ - xName(const std::string& name, xState* ptr) : ::cloe::Action(name), ptr_(ptr) {} \ - ::cloe::ActionPtr clone() const override { return std::make_unique(name(), ptr_); } \ - void operator()(const ::cloe::Sync&, ::cloe::TriggerRegistrar&) override { xOperatorBlock } \ - void to_json(::cloe::Json&) const override {} \ - \ - private: \ - xState* ptr_; \ - }; \ - \ - class _X_FACTORY(xName) : public ::cloe::ActionFactory { \ - public: \ - using ActionType = xName; \ - \ - _X_FACTORY(xName)(xState * ptr) : ::cloe::ActionFactory(xname, xdescription), ptr_(ptr) {} \ - \ - ::cloe::ActionPtr make(const ::cloe::Conf&) const override { \ - return std::make_unique(name(), ptr_); \ - } \ - \ - ::cloe::ActionPtr make(const std::string&) const override { \ - return std::make_unique(name(), ptr_); \ - } \ - \ - private: \ - xState* ptr_; \ +#define DEFINE_SET_STATE_ACTION(xName, xname, xdescription, xState, xOperatorBlock) \ + class xName : public ::cloe::Action { \ + public: \ + xName(const std::string& name, xState* ptr) : ::cloe::Action(name), ptr_(ptr) {} \ + ::cloe::ActionPtr clone() const override { return std::make_unique(name(), ptr_); } \ + ::cloe::CallbackResult operator()(const ::cloe::Sync&, ::cloe::TriggerRegistrar&) override { \ + xOperatorBlock; \ + return ::cloe::CallbackResult::Ok; \ + } \ + void to_json(::cloe::Json&) const override {} \ + \ + private: \ + xState* ptr_; \ + }; \ + \ + class _X_FACTORY(xName) : public ::cloe::ActionFactory { \ + public: \ + using ActionType = xName; \ + \ + _X_FACTORY(xName)(xState * ptr) : ::cloe::ActionFactory(xname, xdescription), ptr_(ptr) {} \ + \ + ::cloe::ActionPtr make(const ::cloe::Conf&) const override { \ + return std::make_unique(name(), ptr_); \ + } \ + \ + ::cloe::ActionPtr make(const std::string&) const override { \ + return std::make_unique(name(), ptr_); \ + } \ + \ + private: \ + xState* ptr_; \ }; /** @@ -203,51 +209,54 @@ class SetVariableActionFactory : public ActionFactory { * This action can be registered with the `register_action` helper function. * Refer to doc/reference/actions.rst for the configuration. */ -#define DEFINE_SET_DATA_ACTION(xName, xActionName, xActionDesc, xDataType, xAttributeName, \ - xAttributeType, xOperatorBlock) \ - class xName : public ::cloe::Action { \ - public: \ - xName(const std::string& action_name, xDataType* ptr, const std::string& attribute_name, \ - const xAttributeType attribute_value) \ - : ::cloe::Action(action_name) \ - , ptr_(ptr) \ - , name_(attribute_name) \ - , value_(attribute_value) {} \ - ::cloe::ActionPtr clone() const override { \ - return std::make_unique(name(), ptr_, name_, value_); \ - } \ - void operator()(const ::cloe::Sync&, ::cloe::TriggerRegistrar&) override { xOperatorBlock } \ - bool is_significant() const override { return false; } \ - void to_json(::cloe::Json& j) const override { \ - j = ::fable::Json{ \ - {name_, value_}, \ - }; \ - } \ - \ - private: \ - xDataType* ptr_; \ - std::string name_; \ - xAttributeType value_; \ - }; \ - \ - class _X_FACTORY(xName) : public ::cloe::ActionFactory { \ - public: \ - using ActionType = xName; \ - _X_FACTORY(xName) \ - (xDataType * ptr) : ::cloe::ActionFactory(xActionName, xActionDesc), ptr_(ptr) {} \ - \ - ::cloe::ActionPtr make(const ::cloe::Conf& c) const override { \ - auto value = c.get(xAttributeName); \ - return std::make_unique(name(), ptr_, xAttributeName, value); \ - } \ - \ - ::cloe::ActionPtr make(const std::string& s) const override { \ - auto value = ::cloe::actions::from_string(s); \ - return make(::fable::Conf{::fable::Json{ \ - {xAttributeName, value}, \ - }}); \ - } \ - \ - private: \ - xDataType* ptr_; \ +#define DEFINE_SET_DATA_ACTION(xName, xActionName, xActionDesc, xDataType, xAttributeName, \ + xAttributeType, xOperatorBlock) \ + class xName : public ::cloe::Action { \ + public: \ + xName(const std::string& action_name, xDataType* ptr, const std::string& attribute_name, \ + const xAttributeType attribute_value) \ + : ::cloe::Action(action_name) \ + , ptr_(ptr) \ + , name_(attribute_name) \ + , value_(attribute_value) {} \ + ::cloe::ActionPtr clone() const override { \ + return std::make_unique(name(), ptr_, name_, value_); \ + } \ + ::cloe::CallbackResult operator()(const ::cloe::Sync&, ::cloe::TriggerRegistrar&) override { \ + xOperatorBlock; \ + return ::cloe::CallbackResult::Ok; \ + } \ + bool is_significant() const override { return false; } \ + void to_json(::cloe::Json& j) const override { \ + j = ::fable::Json{ \ + {name_, value_}, \ + }; \ + } \ + \ + private: \ + xDataType* ptr_; \ + std::string name_; \ + xAttributeType value_; \ + }; \ + \ + class _X_FACTORY(xName) : public ::cloe::ActionFactory { \ + public: \ + using ActionType = xName; \ + _X_FACTORY(xName) \ + (xDataType * ptr) : ::cloe::ActionFactory(xActionName, xActionDesc), ptr_(ptr) {} \ + \ + ::cloe::ActionPtr make(const ::cloe::Conf& c) const override { \ + auto value = c.get(xAttributeName); \ + return std::make_unique(name(), ptr_, xAttributeName, value); \ + } \ + \ + ::cloe::ActionPtr make(const std::string& s) const override { \ + auto value = ::cloe::actions::from_string(s); \ + return make(::fable::Conf{::fable::Json{ \ + {xAttributeName, value}, \ + }}); \ + } \ + \ + private: \ + xDataType* ptr_; \ }; diff --git a/runtime/src/cloe/trigger.cpp b/runtime/src/cloe/trigger.cpp index e5677e4fd..244a46330 100644 --- a/runtime/src/cloe/trigger.cpp +++ b/runtime/src/cloe/trigger.cpp @@ -108,9 +108,9 @@ void TriggerRegistrar::insert_trigger(const std::string& label, EventPtr&& e, Ac insert_trigger(std::make_unique(label, source_, std::move(e), std::move(a))); } -void Callback::execute(TriggerPtr&& t, const Sync& sync) { +CallbackResult Callback::execute(TriggerPtr&& t, const Sync& sync) { assert(executer_); - executer_(std::move(t), sync); + return executer_(std::move(t), sync); } } // namespace cloe diff --git a/runtime/src/cloe/trigger/example_actions.cpp b/runtime/src/cloe/trigger/example_actions.cpp index 6c3e1cb60..011d31c73 100644 --- a/runtime/src/cloe/trigger/example_actions.cpp +++ b/runtime/src/cloe/trigger/example_actions.cpp @@ -31,8 +31,7 @@ #include // for Sync #include // for TriggerRegistrar -namespace cloe { -namespace actions { +namespace cloe::actions { // Log ----------------------------------------------------------------------- TriggerSchema LogFactory::schema() const { @@ -54,7 +53,7 @@ ActionPtr LogFactory::make(const Conf& c) const { ActionPtr LogFactory::make(const std::string& s) const { auto level = spdlog::level::info; - auto pos = s.find(":"); + auto pos = s.find(':'); std::string msg; if (pos != std::string::npos) { try { @@ -76,7 +75,7 @@ ActionPtr LogFactory::make(const std::string& s) const { {"level", logger::to_string(level)}, {"msg", msg}, }}; - if (msg.size() == 0) { + if (msg.empty()) { throw TriggerInvalid(c, "cannot log an empty message"); } return make(c); @@ -108,11 +107,16 @@ ActionPtr Bundle::clone() const { return std::make_unique(name(), std::move(actions)); } -void Bundle::operator()(const Sync& sync, TriggerRegistrar& r) { +CallbackResult Bundle::operator()(const Sync& sync, TriggerRegistrar& r) { logger()->trace("Run action bundle"); + auto result = CallbackResult::Ok; for (auto& a : actions_) { - (*a)(sync, r); + auto ar = (*a)(sync, r); + if (ar == CallbackResult::Unpin) { + result = ar; + } } + return result; } TriggerSchema BundleFactory::schema() const { @@ -141,11 +145,12 @@ void Insert::to_json(Json& j) const { }; } -void Insert::operator()(const Sync&, TriggerRegistrar& r) { +CallbackResult Insert::operator()(const Sync& /*unused*/, TriggerRegistrar& r) { for (const auto& tc : triggers_.to_array()) { auto local = r.make_trigger(tc); r.insert_trigger(std::move(local)); } + return CallbackResult::Ok; } TriggerSchema InsertFactory::schema() const { @@ -170,7 +175,7 @@ ActionPtr InsertFactory::make(const Conf& c) const { } // PushRelease --------------------------------------------------------------- -void PushRelease::operator()(const Sync&, TriggerRegistrar& r) { +CallbackResult PushRelease::operator()(const Sync& /*unused*/, TriggerRegistrar& r) { // clang-format off r.insert_trigger( "push down button(s)", @@ -187,6 +192,7 @@ void PushRelease::operator()(const Sync&, TriggerRegistrar& r) { }}), std::move(release_) ); + return CallbackResult::Ok; // clang-format on } @@ -234,5 +240,4 @@ ActionPtr PushReleaseFactory::make(const Conf& c) const { return std::make_unique(name(), dur, create(true), create(false), repr); } -} // namespace actions -} // namespace cloe +} // namespace cloe::actions From 5ba0e6d4db8bb56278cd4fdd68e081316d396961 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 23:56:31 +0200 Subject: [PATCH 05/22] engine: Restructure code for better compilation times --- engine/CMakeLists.txt | 105 +++++++---- engine/src/config.hpp | 56 ++++++ engine/src/error_handler.hpp | 60 +++++++ engine/src/main.cpp | 80 ++++----- engine/src/{main_check.hpp => main_check.cpp} | 74 +++----- engine/src/main_commands.hpp | 119 +++++++++++++ engine/src/{main_dump.hpp => main_dump.cpp} | 28 +-- engine/src/{main_run.hpp => main_run.cpp} | 81 +++------ engine/src/{main_usage.hpp => main_usage.cpp} | 146 ++++++++------- .../{main_version.hpp => main_version.cpp} | 54 +++--- engine/src/simulation.cpp | 62 ++++--- engine/src/simulation.hpp | 11 +- engine/src/simulation_context.cpp | 7 +- engine/src/simulation_context.hpp | 44 +++-- engine/src/stack.cpp | 35 +++- engine/src/stack.hpp | 48 ++--- engine/src/stack_component_test.cpp | 167 ++++++++++++++++++ .../src/{main_stack.cpp => stack_factory.cpp} | 75 ++++---- .../src/{main_stack.hpp => stack_factory.hpp} | 39 +++- engine/src/stack_test.cpp | 128 -------------- 20 files changed, 845 insertions(+), 574 deletions(-) create mode 100644 engine/src/config.hpp create mode 100644 engine/src/error_handler.hpp rename engine/src/{main_check.hpp => main_check.cpp} (57%) create mode 100644 engine/src/main_commands.hpp rename engine/src/{main_dump.hpp => main_dump.cpp} (63%) rename engine/src/{main_run.hpp => main_run.cpp} (68%) rename engine/src/{main_usage.hpp => main_usage.cpp} (62%) rename engine/src/{main_version.hpp => main_version.cpp} (54%) create mode 100644 engine/src/stack_component_test.cpp rename engine/src/{main_stack.cpp => stack_factory.cpp} (74%) rename engine/src/{main_stack.hpp => stack_factory.hpp} (61%) diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 0ffff85a6..2b6d2fd63 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -24,8 +24,11 @@ set(PROJECT_GIT_REF "unknown") # Library libstack --------------------------------------------------- message(STATUS "Building cloe-stacklib library.") add_library(cloe-stacklib STATIC + src/config.hpp src/stack.hpp src/stack.cpp + src/stack_factory.hpp + src/stack_factory.cpp src/plugin.hpp src/plugin.cpp @@ -63,6 +66,7 @@ if(BUILD_TESTING) message(STATUS "Building test-stacklib executable.") add_executable(test-stacklib src/stack_test.cpp + src/stack_component_test.cpp ) set_target_properties(test-stacklib PROPERTIES CXX_STANDARD 17 @@ -78,69 +82,96 @@ if(BUILD_TESTING) gtest_add_tests(TARGET test-stacklib) endif() -# Executable --------------------------------------------------------- -add_executable(cloe-engine - src/main.cpp - src/main_stack.cpp +# Library libengine ---------------------------------------------- +message(STATUS "Building cloe-enginelib library.") +add_library(cloe-enginelib STATIC src/coordinator.cpp + src/coordinator.hpp + src/registrar.hpp + # These are added below and depend on CLOE_ENGINE_WITH_SERVER: + # src/server.cpp + # src/server_mock.cpp src/simulation.cpp + src/simulation.hpp src/simulation_context.cpp src/simulation_context.hpp src/simulation_progress.hpp src/utility/command.cpp + src/utility/command.hpp + src/utility/defer.hpp + src/utility/progress.hpp + src/utility/state_machine.hpp + src/utility/time_event.hpp ) -set_target_properties(cloe-engine PROPERTIES +add_library(cloe::enginelib ALIAS cloe-enginelib) +set_target_properties(cloe-enginelib PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON - OUTPUT_NAME cloe-engine + OUTPUT_NAME engine ) -target_compile_definitions(cloe-engine - PRIVATE - CLOE_ENGINE_VERSION="${CLOE_ENGINE_VERSION}" - CLOE_ENGINE_TIMESTAMP="${CLOE_ENGINE_TIMESTAMP}" +target_compile_definitions(cloe-enginelib + PUBLIC PROJECT_SOURCE_DIR=\"${CMAKE_CURRENT_SOURCE_DIR}\" ) -target_include_directories(cloe-engine +target_include_directories(cloe-enginelib PRIVATE src ) -target_link_libraries(cloe-engine - PRIVATE - CLI11::CLI11 - cloe::models +target_link_libraries(cloe-enginelib + PUBLIC cloe::stacklib + cloe::models + cloe::runtime + fable::fable + boost::boost + Threads::Threads ) option(CLOE_ENGINE_WITH_SERVER "Enable integrated server component?" ON) if(CLOE_ENGINE_WITH_SERVER) - message(STATUS "-> Enable server component") if(CLOE_FIND_PACKAGES) find_package(cloe-oak REQUIRED QUIET) endif() - target_sources(cloe-engine - PRIVATE - src/server.cpp - ) - target_link_libraries(cloe-engine - PRIVATE - cloe::oak - ) - target_compile_definitions(cloe-engine - PRIVATE - CLOE_ENGINE_WITH_SERVER=1 - ) + target_sources(cloe-enginelib PRIVATE src/server.cpp) + target_link_libraries(cloe-enginelib PRIVATE cloe::oak) + target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_SERVER=1) else() - message(STATUS "-> Disable server component") - target_sources(cloe-engine - PRIVATE - src/server_mock.cpp - ) - target_compile_definitions(cloe-engine - PRIVATE - CLOE_ENGINE_WITH_SERVER=0 - ) + target_sources(cloe-enginelib PRIVATE src/server_mock.cpp) + target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_SERVER=0) endif() +# Executable --------------------------------------------------------- +message(STATUS "Building cloe-engine executable [with server=${CLOE_ENGINE_WITH_SERVER}].") +add_executable(cloe-engine + src/main.cpp + src/main_commands.hpp + src/main_check.cpp + src/main_dump.cpp + src/main_run.cpp + src/main_usage.cpp + src/main_version.cpp +) +set_target_properties(cloe-engine PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + OUTPUT_NAME cloe-engine +) +target_compile_definitions(cloe-engine + PRIVATE + CLOE_ENGINE_VERSION="${CLOE_ENGINE_VERSION}" + CLOE_ENGINE_TIMESTAMP="${CLOE_ENGINE_TIMESTAMP}" +) +target_include_directories(cloe-engine + PRIVATE + src +) +target_link_libraries(cloe-engine + PRIVATE + cloe::stacklib + cloe::enginelib + CLI11::CLI11 +) + # Installation ------------------------------------------------------- install(TARGETS cloe-engine RUNTIME diff --git a/engine/src/config.hpp b/engine/src/config.hpp new file mode 100644 index 000000000..93e56cabc --- /dev/null +++ b/engine/src/config.hpp @@ -0,0 +1,56 @@ +/* + * Copyright 2020 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file config.hpp + */ + +#pragma once + +#ifndef CLOE_CONTACT_EMAIL +#define CLOE_CONTACT_EMAIL "cloe-dev@eclipse.org" +#endif + +#ifndef CLOE_STACK_VERSION +#define CLOE_STACK_VERSION "4.1" +#endif + +#ifndef CLOE_STACK_SUPPORTED_VERSIONS +#define CLOE_STACK_SUPPORTED_VERSIONS {"4", "4.0", "4.1"} +#endif + +#ifndef CLOE_XDG_SUFFIX +#define CLOE_XDG_SUFFIX "cloe" +#endif + +#ifndef CLOE_CONFIG_HOME +#define CLOE_CONFIG_HOME "${XDG_CONFIG_HOME-${HOME}/.config}/" CLOE_XDG_SUFFIX +#endif + +#ifndef CLOE_DATA_HOME +#define CLOE_DATA_HOME "${XDG_DATA_HOME-${HOME}/.local/share}/" CLOE_XDG_SUFFIX +#endif + +#ifndef CLOE_SIMULATION_UUID_VAR +#define CLOE_SIMULATION_UUID_VAR "CLOE_SIMULATION_UUID" +#endif + +// The environment variable from which additional plugins should +// be loaded. Takes the same format as PATH. +#ifndef CLOE_PLUGIN_PATH +#define CLOE_PLUGIN_PATH "CLOE_PLUGIN_PATH" +#endif diff --git a/engine/src/error_handler.hpp b/engine/src/error_handler.hpp new file mode 100644 index 000000000..acbcd023e --- /dev/null +++ b/engine/src/error_handler.hpp @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include // for Error +#include // for ConfError, SchemaError +#include // for indent_string, pretty_print + +namespace cloe { + +inline std::string format_error(const std::exception& exception) { + std::stringstream buf; + if (const auto* err = dynamic_cast(&exception); err) { + fable::pretty_print(*err, buf); + } else if (const auto* err = dynamic_cast(&exception); err) { + fable::pretty_print(*err, buf); + } else if (const auto* err = dynamic_cast(&exception); err) { + buf << err->what() << "\n"; + if (err->has_explanation()) { + buf << " Note:\n"; + buf << fable::indent_string(err->explanation(), " "); + } + } else { + buf << exception.what(); + } + return buf.str(); +} + +template +auto conclude_error(std::ostream& out, Func f) -> decltype(f()) { + try { + return f(); + } catch (cloe::ConcludedError&) { + // Has already been logged. + throw; + } catch (std::exception& err) { + out << "Error: " << format_error(err) << std::endl; + throw cloe::ConcludedError(err); + } +} + +} // namespace cloe diff --git a/engine/src/main.cpp b/engine/src/main.cpp index fdcc024f8..695fdd997 100644 --- a/engine/src/main.cpp +++ b/engine/src/main.cpp @@ -15,69 +15,58 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main.cpp - * \see main_check.hpp - * \see main_dump.hpp - * \see main_run.hpp - * \see main_usage.hpp - * \see main_version.hpp - */ #include // for cerr #include // for string +#include // for swap #include -#include "main_check.hpp" -#include "main_dump.hpp" -#include "main_run.hpp" -#include "main_stack.hpp" -#include "main_usage.hpp" -#include "main_version.hpp" +#include +#include -#ifndef CLOE_CONTACT_EMAIL -#define CLOE_CONTACT_EMAIL "cloe-dev@eclipse.org" -#endif +#include "config.hpp" +#include "main_commands.hpp" int main(int argc, char** argv) { CLI::App app("Cloe " CLOE_ENGINE_VERSION); + app.option_defaults()->always_capture_default(); // Version Command: - engine::VersionOptions version_options; - auto version = app.add_subcommand("version", "Show program version information."); + engine::VersionOptions version_options{}; + auto* version = app.add_subcommand("version", "Show program version information."); version->add_flag("-j,--json", version_options.output_json, "Output version information as JSON data"); version->add_option("-J,--json-indent", version_options.json_indent, "JSON indentation level"); // Usage Command: - engine::UsageOptions usage_options; + engine::UsageOptions usage_options{}; std::string usage_key_or_path; - auto usage = app.add_subcommand("usage", "Show schema or plugin usage information."); + auto* usage = app.add_subcommand("usage", "Show schema or plugin usage information."); usage->add_flag("-j,--json", usage_options.output_json, "Output global/plugin JSON schema"); usage->add_option("-J,--json-indent", usage_options.json_indent, "JSON indentation level"); usage->add_option("files", usage_key_or_path, "Plugin name, key or path to show schema of"); // Dump Command: - engine::DumpOptions dump_options; + engine::DumpOptions dump_options{}; std::vector dump_files; - auto dump = app.add_subcommand("dump", "Dump configuration of (merged) stack files."); + auto* dump = app.add_subcommand("dump", "Dump configuration of (merged) stack files."); dump->add_option("-J,--json-indent", dump_options.json_indent, "JSON indentation level"); dump->add_option("files", dump_files, "Files to read into the stack"); // Check Command: - engine::CheckOptions check_options; + engine::CheckOptions check_options{}; std::vector check_files; - auto check = app.add_subcommand("check", "Validate stack file configurations."); + auto* check = app.add_subcommand("check", "Validate stack file configurations."); check->add_flag("-s,--summarize", check_options.summarize, "Summarize results"); check->add_flag("-j,--json", check_options.output_json, "Output results as JSON data"); check->add_option("-J,--json-indent", check_options.json_indent, "JSON indentation level"); check->add_option("files", check_files, "Files to check"); // Run Command: - engine::RunOptions run_options; - std::vector run_files; - auto run = app.add_subcommand("run", "Run a simulation with (merged) stack files."); + engine::RunOptions run_options{}; + std::vector run_files{}; + auto* run = app.add_subcommand("run", "Run a simulation with (merged) stack files."); run->add_option("-J,--json-indent", run_options.json_indent, "JSON indentation level"); run->add_option("-u,--uuid", run_options.uuid, "Override simulation UUID") ->envname("CLOE_SIMULATION_UUID"); @@ -103,10 +92,10 @@ int main(int argc, char** argv) { ->envname("CLOE_LOG_LEVEL"); // Stack Options: - cloe::StackOptions stack_options; - stack_options.environment.reset(new fable::Environment()); + cloe::StackOptions stack_options{}; + stack_options.environment = std::make_unique(); app.add_option("-p,--plugin-path", stack_options.plugin_paths, - "Scan additional directory for plugins"); + "Scan additional directory for plugins (Env:CLOE_PLUGIN_PATH)"); app.add_option("-i,--ignore", stack_options.ignore_sections, "Ignore sections by JSON pointer syntax"); app.add_flag("--no-builtin-plugins", stack_options.no_builtin_plugins, @@ -165,8 +154,8 @@ int main(int argc, char** argv) { stack_options.environment->insert(CLOE_SIMULATION_UUID_VAR, "${" CLOE_SIMULATION_UUID_VAR "}"); } - auto with_stack_options = [&](auto& opt) -> decltype(opt) { - opt.stack_options = stack_options; + auto with_global_options = [&](auto& opt) -> decltype(opt) { + std::swap(opt.stack_options, stack_options); return opt; }; @@ -175,19 +164,17 @@ int main(int argc, char** argv) { try { if (*version) { return engine::version(version_options); + } else if (*usage) { + return engine::usage(with_global_options(usage_options), usage_key_or_path); + } else if (*dump) { + return engine::dump(with_global_options(dump_options), dump_files); + } else if (*check) { + return engine::check(with_global_options(check_options), check_files); + } else if (*run) { + return engine::run(with_global_options(run_options), run_files); } - if (*usage) { - return engine::usage(with_stack_options(usage_options), usage_key_or_path); - } - if (*dump) { - return engine::dump(with_stack_options(dump_options), dump_files); - } - if (*check) { - return engine::check(with_stack_options(check_options), check_files); - } - if (*run) { - return engine::run(with_stack_options(run_options), run_files); - } + } catch (cloe::ConcludedError& e) { + return EXIT_FAILURE; } catch (std::exception& e) { bool is_logic_error = false; if (dynamic_cast(&e) != nullptr) { @@ -218,9 +205,6 @@ int main(int argc, char** argv) { std::cerr << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n" << std::endl; - - // Write a core dump. - throw; } return EXIT_FAILURE; diff --git a/engine/src/main_check.hpp b/engine/src/main_check.cpp similarity index 57% rename from engine/src/main_check.hpp rename to engine/src/main_check.cpp index 57693f3ba..774df6c79 100644 --- a/engine/src/main_check.hpp +++ b/engine/src/main_check.cpp @@ -15,33 +15,17 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main_check.hpp - * \see main.cpp - * - * This file contains the "check" options and command. - */ - -#pragma once #include // for ostream, cout #include // for string #include // for vector<> -#include "main_stack.hpp" // for Stack, StackOptions, new_stack +#include -namespace engine { - -struct CheckOptions { - cloe::StackOptions stack_options; - std::ostream& output = std::cout; - std::string delimiter = ","; +#include "main_commands.hpp" +#include "stack.hpp" - // Flags: - bool summarize = false; - bool output_json = false; - int json_indent = 2; -}; +namespace engine { /** * Output nothing in the case that a file is valid, and an error message if @@ -49,15 +33,15 @@ struct CheckOptions { * * This mirrors most closely the standard unix command-line philosophy. */ -inline void check_stack(const cloe::StackOptions& opt, const std::vector& files, - bool* ok = nullptr) { - if (ok) { - *ok = false; +void check_stack(const cloe::StackOptions& opt, const std::vector& files, + bool* okay = nullptr) { + if (okay != nullptr) { + *okay = false; } - cloe::Stack s = cloe::new_stack(opt, files); - s.check_completeness(); - if (ok) { - *ok = true; + auto stack = cloe::new_stack(opt, files); + stack.check_completeness(); + if (okay != nullptr) { + *okay = true; } } @@ -66,13 +50,13 @@ inline void check_stack(const cloe::StackOptions& opt, const std::vector& files, - bool* ok = nullptr) { +std::string check_summary(const CheckOptions& opt, const std::vector& files, + bool* okay = nullptr) { cloe::StackOptions stack_opt = opt.stack_options; - stack_opt.error = boost::none; + stack_opt.error = nullptr; try { - check_stack(stack_opt, files, ok); + check_stack(stack_opt, files, okay); return "OK"; } catch (cloe::StackIncompleteError& e) { return "INCOMPLETE (" + std::string(e.what()) + ")"; @@ -87,16 +71,16 @@ inline std::string check_summary(const CheckOptions& opt, const std::vector& files, - bool* ok = nullptr) { +cloe::Json check_json(const CheckOptions& opt, const std::vector& files, + bool* okay = nullptr) { cloe::StackOptions stack_opt = opt.stack_options; - stack_opt.error = boost::none; + stack_opt.error = nullptr; if (opt.summarize) { - return check_summary(opt, files, ok); + return check_summary(opt, files, okay); } else { try { - check_stack(stack_opt, files, ok); + check_stack(stack_opt, files, okay); return nullptr; } catch (cloe::SchemaError& e) { return e; @@ -110,25 +94,25 @@ inline cloe::Json check_json(const CheckOptions& opt, const std::vector& filepaths) { - bool ok = false; +int check_merged(const CheckOptions& opt, const std::vector& filepaths) { + bool okay = false; if (opt.output_json) { - opt.output << check_json(opt, filepaths, &ok).dump(opt.json_indent) << std::endl; + *opt.output << check_json(opt, filepaths, &okay).dump(opt.json_indent) << std::endl; } else if (opt.summarize) { - opt.output << check_summary(opt, filepaths, &ok) << std::endl; + *opt.output << check_summary(opt, filepaths, &okay) << std::endl; } else { try { - check_stack(opt.stack_options, filepaths, &ok); + check_stack(opt.stack_options, filepaths, &okay); } catch (cloe::ConcludedError&) { } catch (std::exception& e) { - opt.output << e.what() << std::endl; + *opt.output << e.what() << std::endl; } } - return ok ? EXIT_SUCCESS : EXIT_FAILURE; + return okay ? EXIT_SUCCESS : EXIT_FAILURE; } -inline int check(const CheckOptions& opt, const std::vector& filepaths) { +int check(const CheckOptions& opt, const std::vector& filepaths) { return check_merged(opt, filepaths); } diff --git a/engine/src/main_commands.hpp b/engine/src/main_commands.hpp new file mode 100644 index 000000000..94aac1d65 --- /dev/null +++ b/engine/src/main_commands.hpp @@ -0,0 +1,119 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file main_commands.hpp + * \see main.cpp + */ + +#include +#include + +#include // for optional<> + +#include "stack_factory.hpp" + +namespace engine { + +struct CheckOptions { + cloe::StackOptions stack_options; + + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + std::string delimiter = ","; + + // Flags: + bool summarize = false; + bool output_json = false; + int json_indent = 2; +}; + +int check(const CheckOptions& opt, const std::vector& filepaths); + +struct DumpOptions { + cloe::StackOptions stack_options; + + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + + // Flags: + int json_indent = 2; +}; + +int dump(const DumpOptions& opt, const std::vector& filepaths); + +struct RunOptions { + cloe::StackOptions stack_options; + + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + + // Options + std::string uuid; + + // Flags: + int json_indent = 2; + bool allow_empty = false; + bool write_output = true; + bool require_success = false; + bool report_progress = true; +}; + +int run(const RunOptions& opt, const std::vector& filepaths); + +struct ShellOptions { + cloe::StackOptions stack_options; + + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + + // Options: + std::vector commands; + + // Flags: + std::optional interactive; + bool ignore_errors; +}; + +int shell(const ShellOptions& opt, const std::vector& filepaths); + +struct UsageOptions { + cloe::StackOptions stack_options; + + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + + // Flags: + bool plugin_usage = false; + bool output_json = false; + int json_indent = 2; +}; + +int usage(const UsageOptions& opt, const std::string& argument); + +struct VersionOptions { + std::ostream* output = &std::cout; + std::ostream* error = &std::cerr; + + // Flags: + bool output_json = false; + int json_indent = 2; +}; + +int version(const VersionOptions& opt); + +} // namespace engine diff --git a/engine/src/main_dump.hpp b/engine/src/main_dump.cpp similarity index 63% rename from engine/src/main_dump.hpp rename to engine/src/main_dump.cpp index 845482b74..2894c9372 100644 --- a/engine/src/main_dump.hpp +++ b/engine/src/main_dump.cpp @@ -15,35 +15,23 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main_dump.hpp - * \see main.cpp - * - * This file contains the "dump" options and commands. - */ - -#pragma once #include // for ostream, cout #include // for string #include // for vector<> -#include "main_stack.hpp" // for Stack, new_stack +#include -namespace engine { - -struct DumpOptions { - cloe::StackOptions stack_options; - std::ostream& output = std::cout; +#include "main_commands.hpp" // for DumpOptions, new_stack +#include "stack.hpp" // for Stack - // Flags: - int json_indent = 2; -}; +namespace engine { -inline int dump(const DumpOptions& opt, const std::vector& filepaths) { +int dump(const DumpOptions& opt, const std::vector& filepaths) { + assert(opt.output != nullptr && opt.error != nullptr); try { - cloe::Stack s = cloe::new_stack(opt.stack_options, filepaths); - opt.output << s.to_json().dump(opt.json_indent) << std::endl; + auto stack = cloe::new_stack(opt.stack_options, filepaths); + *opt.output << stack.to_json().dump(opt.json_indent) << std::endl; return EXIT_SUCCESS; } catch (cloe::ConcludedError& e) { return EXIT_FAILURE; diff --git a/engine/src/main_run.hpp b/engine/src/main_run.cpp similarity index 68% rename from engine/src/main_run.hpp rename to engine/src/main_run.cpp index 7489d5d40..06a81e5de 100644 --- a/engine/src/main_run.hpp +++ b/engine/src/main_run.cpp @@ -15,14 +15,6 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main_run.hpp - * \see main.cpp - * - * This file contains the "run" options and command. - */ - -#pragma once #include // for signal #include // for getenv @@ -33,6 +25,7 @@ // we still need to support earlier versions of Boost. #define BOOST_ALLOW_DEPRECATED_HEADERS +#include #include // for lexical_cast #include // for random_generator #include @@ -40,48 +33,22 @@ #include // for logger::get #include // for read_conf -#include "main_stack.hpp" // for Stack, new_stack -#include "simulation.hpp" // for Simulation, SimulationResult -#include "stack.hpp" // for Stack +#include "error_handler.hpp" // for conclude_error +#include "main_commands.hpp" // for RunOptions, new_stack +#include "simulation.hpp" // for Simulation, SimulationResult +#include "stack.hpp" // for Stack namespace engine { -void handle_signal(int); - -struct RunOptions { - cloe::StackOptions stack_options; - std::ostream& output = std::cout; - std::ostream& error = std::cerr; - - // Options - std::string uuid; +void handle_signal(int /*sig*/); - // Flags: - int json_indent = 2; - bool allow_empty = false; - bool write_output = true; - bool require_success = false; - bool report_progress = true; -}; +// We need a global instance so that our signal handler has access to it. +Simulation* GLOBAL_SIMULATION_INSTANCE{nullptr}; // NOLINT -Simulation* GLOBAL_SIMULATION_INSTANCE{nullptr}; - -template -auto handle_cloe_error(std::ostream& out, Func f) -> decltype(f()) { - try { - return f(); - } catch (cloe::Error& e) { - out << "Error: " << e.what() << std::endl; - if (e.has_explanation()) { - out << " Note:\n" << fable::indent_string(e.explanation(), " ") << std::endl; - } - throw cloe::ConcludedError(e); - } -} - -inline int run(const RunOptions& opt, const std::vector& filepaths) { +int run(const RunOptions& opt, const std::vector& filepaths) { + assert(opt.output != nullptr && opt.error != nullptr); + auto log = cloe::logger::get("cloe"); cloe::logger::get("cloe")->info("Cloe {}", CLOE_ENGINE_VERSION); - cloe::StackOptions stack_opt = opt.stack_options; // Set the UUID of the simulation: std::string uuid; @@ -92,15 +59,17 @@ inline int run(const RunOptions& opt, const std::vector& filepaths) } else { uuid = boost::lexical_cast(boost::uuids::random_generator()()); } - stack_opt.environment->set(CLOE_SIMULATION_UUID_VAR, uuid); + opt.stack_options.environment->set(CLOE_SIMULATION_UUID_VAR, uuid); // Load the stack file: - cloe::Stack s; + cloe::Stack stack = cloe::new_stack(opt.stack_options); try { - handle_cloe_error(*stack_opt.error, [&]() { - s = cloe::new_stack(stack_opt, filepaths); + cloe::conclude_error(*opt.stack_options.error, [&]() { + for (const auto& file : filepaths) { + cloe::merge_stack(opt.stack_options, stack, file); + } if (!opt.allow_empty) { - s.check_completeness(); + stack.check_completeness(); } }); } catch (cloe::ConcludedError& e) { @@ -108,15 +77,15 @@ inline int run(const RunOptions& opt, const std::vector& filepaths) } // Create simulation: - Simulation sim(s, uuid); + Simulation sim(std::move(stack), uuid); GLOBAL_SIMULATION_INSTANCE = ∼ - std::signal(SIGINT, handle_signal); + std::ignore = std::signal(SIGINT, handle_signal); // Set options: sim.set_report_progress(opt.report_progress); // Run simulation: - auto result = handle_cloe_error(*stack_opt.error, [&]() { return sim.run(); }); + auto result = cloe::conclude_error(*opt.stack_options.error, [&]() { return sim.run(); }); if (result.outcome == SimulationOutcome::NoStart) { // If we didn't get past the initialization phase, don't output any // statistics or write any files, just go home. @@ -127,7 +96,7 @@ inline int run(const RunOptions& opt, const std::vector& filepaths) if (opt.write_output) { sim.write_output(result); } - opt.output << cloe::Json(result).dump(opt.json_indent) << std::endl; + *opt.output << cloe::Json(result).dump(opt.json_indent) << std::endl; switch (result.outcome) { case SimulationOutcome::Success: @@ -160,7 +129,7 @@ inline int run(const RunOptions& opt, const std::vector& filepaths) * by the standard library, so that in the case that we do hang for some * reasons, the user can force abort by sending the signal a third time. */ -inline void handle_signal(int sig) { +void handle_signal(int sig) { static size_t interrupts = 0; switch (sig) { case SIGSEGV: @@ -171,9 +140,9 @@ inline void handle_signal(int sig) { default: std::cerr << std::endl; // print newline so that ^C is on its own line if (++interrupts == 3) { - std::signal(sig, SIG_DFL); // third time goes to the default handler + std::ignore = std::signal(sig, SIG_DFL); // third time goes to the default handler } - if (GLOBAL_SIMULATION_INSTANCE) { + if (GLOBAL_SIMULATION_INSTANCE != nullptr) { GLOBAL_SIMULATION_INSTANCE->signal_abort(); } break; diff --git a/engine/src/main_usage.hpp b/engine/src/main_usage.cpp similarity index 62% rename from engine/src/main_usage.hpp rename to engine/src/main_usage.cpp index f09d9f85f..aefc99425 100644 --- a/engine/src/main_usage.hpp +++ b/engine/src/main_usage.cpp @@ -15,14 +15,6 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main_usage.hpp - * \see main.cpp - * - * This file contains the "usage" options and command. - */ - -#pragma once #include // for ostream #include // for shared_ptr<> @@ -32,28 +24,18 @@ #include // for find_all_config -#include "main_stack.hpp" // for Stack, new_stack +#include "main_commands.hpp" // for new_stack +#include "stack.hpp" // for Stack namespace engine { -struct UsageOptions { - cloe::StackOptions stack_options; - std::ostream& output = std::cout; - std::ostream& error = std::cerr; - - // Flags: - bool plugin_usage = false; - bool output_json = false; - int json_indent = 2; -}; - -void show_usage(cloe::Stack s, std::ostream& output); -void show_plugin_usage(std::shared_ptr p, std::ostream& os, bool json, size_t indent); +void show_usage(const cloe::Stack& stack, std::ostream& out); +void show_plugin_usage(const cloe::Plugin& plugin, std::ostream& out, bool json, int indent); -inline int usage(const UsageOptions& opt, const std::string& argument) { - cloe::Stack s; +int usage(const UsageOptions& opt, const std::string& argument) { + cloe::Stack stack; try { - s = cloe::new_stack(opt.stack_options); + stack = cloe::new_stack(opt.stack_options); } catch (cloe::ConcludedError& e) { return EXIT_FAILURE; } @@ -62,13 +44,13 @@ inline int usage(const UsageOptions& opt, const std::string& argument) { bool result = true; if (argument.empty()) { if (opt.output_json) { - opt.output << s.schema().json_schema().dump(opt.json_indent) << std::endl; + *opt.output << stack.schema().json_schema().dump(opt.json_indent) << std::endl; } else { - show_usage(s, opt.output); + show_usage(stack, *opt.output); } } else { - std::shared_ptr p = s.get_plugin_or_load(argument); - show_plugin_usage(p, opt.output, opt.output_json, opt.json_indent); + std::shared_ptr plugin = stack.get_plugin_or_load(argument); + show_plugin_usage(*plugin, *opt.output, opt.output_json, static_cast(opt.json_indent)); } return result ? EXIT_SUCCESS : EXIT_FAILURE; } @@ -76,47 +58,64 @@ inline int usage(const UsageOptions& opt, const std::string& argument) { // --------------------------------------------------------------------------------------------- // template -void print_plugin_usage(std::ostream& os, const cloe::Plugin& p, const std::string& prefix = " ") { - auto f = p.make(); - auto u = f->schema().usage_compact(); - os << dump_json(u, prefix) << std::endl; +void print_plugin_usage(std::ostream& out, const cloe::Plugin& plugin, + const std::string& prefix = " ") { + auto factory = plugin.make(); + auto usage = factory->schema().usage_compact(); + out << dump_json(usage, prefix) << std::endl; } /** * Print a nicely formatted list of available plugins. + * + * Output looks like: + * + * Available simulators: + * nop [builtin://simulator/nop] + * + * Available controllers: + * basic [/path/to/basic.so] + * nop [builtin://controller/nop] + * + * Available components: + * noisy_lane_sensor [/path/to/noisy_lane_sensor.so] + * speedometer [/path/to/speedometer.so] + * */ -inline void print_available_plugins(const cloe::Stack& s, std::ostream& os, - const std::string& word = "Available") { +void print_available_plugins(const cloe::Stack& stack, std::ostream& out, + const std::string& word = "Available") { const std::string prefix = " "; auto print_available = [&](const std::string& type) { - os << word << " " << type << "s:" << std::endl; + out << word << " " << type << "s:" << std::endl; - std::vector> vec; - for (auto& kv : s.get_all_plugins()) { - if (kv.second->type() == type) { - vec.emplace_back(std::make_pair(kv.second->name(), kv.first)); + std::vector> plugins; + // Get and filter out plugins that are the wanted type. + for (const auto& pair : stack.get_all_plugins()) { + if (pair.second->type() == type) { + plugins.emplace_back(pair.second->name(), pair.first); } } - if (vec.empty()) { - os << prefix << "n/a" << std::endl << std::endl; + if (plugins.empty()) { + out << prefix << "n/a" << std::endl << std::endl; return; } // Calculate how wide the first column needs to be: size_t max_length = 0; - for (auto x : vec) { - if (x.first.size() > max_length) { - max_length = x.first.size(); + for (const auto& pair : plugins) { + if (pair.first.size() > max_length) { + max_length = pair.first.size(); } } // Print the available names: - for (auto x : vec) { - auto n = x.first.size(); - os << prefix << x.first << std::string(max_length - n, ' ') << " [" << x.second << "]\n"; + for (const auto& pair : plugins) { + auto name_len = pair.first.size(); + out << prefix << pair.first << std::string(max_length - name_len, ' ') << " [" << pair.second + << "]\n"; } - os << std::endl; + out << std::endl; }; print_available("simulator"); @@ -127,10 +126,10 @@ inline void print_available_plugins(const cloe::Stack& s, std::ostream& os, /** * Print full program usage. */ -inline void show_usage(cloe::Stack s, std::ostream& os) { - os << fmt::format("Cloe {} ({})", CLOE_ENGINE_VERSION, CLOE_ENGINE_TIMESTAMP) << std::endl; +void show_usage(const cloe::Stack& stack, std::ostream& out) { + out << fmt::format("Cloe {} ({})", CLOE_ENGINE_VERSION, CLOE_ENGINE_TIMESTAMP) << std::endl; - os << R"( + out << R"( Cloe is a simulation middleware tool that ties multiple plugins together into a cohesive and coherent simulation. This is performed based on JSON input that we name "stack files". @@ -225,40 +224,39 @@ Please report any bugs to: cloe-dev@eclipse.org { auto files = cloe::utility::find_all_config(CLOE_XDG_SUFFIX "/config.json"); - if (files.size() != 0) { - os << "Discovered default configuration files:" << std::endl; - for (auto& f : files) { - os << " " << f.native() << std::endl; + if (files.empty()) { + out << "Discovered default configuration files:" << std::endl; + for (auto& file : files) { + out << " " << file.native() << std::endl; } - os << std::endl; + out << std::endl; } } - print_available_plugins(std::move(s), os); + print_available_plugins(stack, out); } -inline void show_plugin_usage(std::shared_ptr p, std::ostream& os, bool json, - size_t indent) { - auto m = p->make(); +void show_plugin_usage(const cloe::Plugin& plugin, std::ostream& out, bool as_json, int indent) { + auto factory = plugin.make(); - if (json) { - cloe::Json js = m->schema().json_schema_qualified(p->path()); - js["title"] = m->name(); - js["description"] = m->description(); - os << js.dump(indent) << std::endl; + if (as_json) { + cloe::Json json = factory->schema().json_schema_qualified(plugin.path()); + json["title"] = factory->name(); + json["description"] = factory->description(); + out << json.dump(indent) << std::endl; return; } - os << "Name: " << m->name() << std::endl; - os << "Type: " << p->type() << std::endl; - os << "Path: "; - if (p->path() == "") { - os << "n/a" << std::endl; + out << "Name: " << factory->name() << std::endl; + out << "Type: " << plugin.type() << std::endl; + out << "Path: "; + if (plugin.path().empty()) { + out << "n/a" << std::endl; } else { - os << p->path() << std::endl; + out << plugin.path() << std::endl; } - os << "Usage: " << m->schema().usage().dump(indent) << std::endl; - os << "Defaults: " << m->to_json().dump(indent) << std::endl; + out << "Usage: " << factory->schema().usage().dump(indent) << std::endl; + out << "Defaults: " << factory->to_json().dump(indent) << std::endl; } } // namespace engine diff --git a/engine/src/main_version.hpp b/engine/src/main_version.cpp similarity index 54% rename from engine/src/main_version.hpp rename to engine/src/main_version.cpp index b2bd47059..833c83dca 100644 --- a/engine/src/main_version.hpp +++ b/engine/src/main_version.cpp @@ -15,46 +15,20 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file main_version.hpp - * \see main.cpp - * - * This file contains the "version" options and command. - */ - -#pragma once #include // for ostream, cout +#include #include // for CLOE_PLUGIN_MANIFEST_VERSION #include // for inja_env -#include "stack.hpp" // for CLOE_STACK_VERSION +#include "config.hpp" // for CLOE_STACK_VERSION +#include "main_commands.hpp" // for VersionOptions namespace engine { -struct VersionOptions { - std::ostream& output = std::cout; - - // Flags: - bool output_json = false; - int json_indent = 2; -}; - -inline int version(const VersionOptions& opt) { - cloe::Json v{ - {"engine", CLOE_ENGINE_VERSION}, // from CMakeLists.txt - {"build_date", CLOE_ENGINE_TIMESTAMP}, // from CMakeLists.txt - {"stack", CLOE_STACK_VERSION}, // from "stack.hpp" - {"plugin_manifest", CLOE_PLUGIN_MANIFEST_VERSION}, // from - {"feature_server", CLOE_ENGINE_WITH_SERVER ? true : false}, // from CMakeLists.txt - }; - - if (opt.output_json) { - opt.output << v.dump(opt.json_indent) << std::endl; - } else { - auto env = cloe::utility::inja_env(); - opt.output << env.render(R"(Cloe [[engine]] +static const constexpr char* VERSION_TMPL = + R"(Cloe [[engine]] Engine Version: [[engine]] Build Date: [[build_date]] @@ -62,8 +36,22 @@ Stack: [[stack]] Plugin Manifest: [[plugin_manifest]] Features: server: [[feature_server]] -)", - v); +)"; + +int version(const VersionOptions& opt) { + cloe::Json metadata{ + {"engine", CLOE_ENGINE_VERSION}, // from CMakeLists.txt + {"build_date", CLOE_ENGINE_TIMESTAMP}, // from CMakeLists.txt + {"stack", CLOE_STACK_VERSION}, // from "stack.hpp" + {"plugin_manifest", CLOE_PLUGIN_MANIFEST_VERSION}, // from + {"feature_server", CLOE_ENGINE_WITH_SERVER != 0}, // from CMakeLists.txt + }; + + if (opt.output_json) { + *opt.output << metadata.dump(opt.json_indent) << std::endl; + } else { + auto env = cloe::utility::inja_env(); + *opt.output << env.render(VERSION_TMPL, metadata); } return EXIT_SUCCESS; diff --git a/engine/src/simulation.cpp b/engine/src/simulation.cpp index 279bd6f8e..3031815fd 100644 --- a/engine/src/simulation.cpp +++ b/engine/src/simulation.cpp @@ -84,11 +84,14 @@ #include // for is_directory, is_regular_file, ... +#include // for Controller #include // for AsyncAbort #include // for DirectCallback +#include // for Simulator #include // for CommandFactory, BundleFactory, ... #include // for DEFINE_SET_STATE_ACTION, SetDataActionFactory #include // for INCLUDE_RESOURCE, RESOURCE_HANDLER +#include // for Vehicle #include // for pretty_print #include "simulation_context.hpp" // for SimulationContext @@ -251,7 +254,7 @@ class SimulationMachine } #define DEFINE_STATE(Id, S) DEFINE_STATE_STRUCT(SimulationMachine, SimulationContext, Id, S) - private: + public: DEFINE_STATE(CONNECT, Connect); DEFINE_STATE(START, Start); DEFINE_STATE(STEP_BEGIN, StepBegin); @@ -279,12 +282,8 @@ DEFINE_SET_STATE_ACTION(Stop, "stop", "stop simulation with neither success nor DEFINE_SET_STATE_ACTION(Succeed, "succeed", "stop simulation with success", SimulationMachine, { ptr_->succeed(); }) DEFINE_SET_STATE_ACTION(Fail, "fail", "stop simulation with failure", SimulationMachine, { ptr_->fail(); }) DEFINE_SET_STATE_ACTION(Reset, "reset", "attempt to reset simulation", SimulationMachine, { ptr_->reset(); }) - -DEFINE_SET_STATE_ACTION(KeepAlive, "keep_alive", "keep simulation alive after termination", - SimulationContext, { ptr_->config.engine.keep_alive = true; }) - -DEFINE_SET_STATE_ACTION(ResetStatistics, "reset_statistics", "reset simulation statistics", - SimulationStatistics, { ptr_->reset(); }) +DEFINE_SET_STATE_ACTION(KeepAlive, "keep_alive", "keep simulation alive after termination", SimulationContext, { ptr_->config.engine.keep_alive = true; }) +DEFINE_SET_STATE_ACTION(ResetStatistics, "reset_statistics", "reset simulation statistics", SimulationStatistics, { ptr_->reset(); }) DEFINE_SET_DATA_ACTION(RealtimeFactor, "realtime_factor", "modify the simulation speed", SimulationSync, "factor", double, { @@ -728,33 +727,38 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { // START --------------------------------------------------------------------------------------- // -StateId SimulationMachine::Start::impl(SimulationContext& ctx) { - logger()->info("Starting simulation..."); - - // Begin execution progress - ctx.progress.exec_begin(); - - // Insert triggers from the config +size_t insert_triggers_from_config(SimulationContext& ctx) { auto r = ctx.coordinator->trigger_registrar(cloe::Source::FILESYSTEM); + size_t count = 0; for (const auto& c : ctx.config.triggers) { if (!ctx.config.engine.triggers_ignore_source && source_is_transient(c.source)) { continue; } try { r->insert_trigger(c.conf()); + count++; } catch (cloe::SchemaError& e) { - logger()->error("Error inserting trigger: {}", e.what()); + ctx.logger()->error("Error inserting trigger: {}", e.what()); std::stringstream s; fable::pretty_print(e, s); - logger()->error("> Message:\n {}", s.str()); - return ABORT; + ctx.logger()->error("> Message:\n {}", s.str()); + throw cloe::ConcludedError(e); } catch (cloe::TriggerError& e) { - logger()->error("Error inserting trigger ({}): {}", e.what(), c.to_json().dump()); - return ABORT; + ctx.logger()->error("Error inserting trigger ({}): {}", e.what(), c.to_json().dump()); + throw cloe::ConcludedError(e); } } + return count; +} + +StateId SimulationMachine::Start::impl(SimulationContext& ctx) { + logger()->info("Starting simulation..."); + + // Begin execution progress + ctx.progress.exec_begin(); // Process initial trigger list + insert_triggers_from_config(ctx); ctx.coordinator->process(ctx.sync); ctx.callback_start->trigger(ctx.sync); @@ -808,7 +812,7 @@ StateId SimulationMachine::StepBegin::impl(SimulationContext& ctx) { // ctx.server->refresh_buffer(); - // Run time-based triggers + // Run cycle- and time-based triggers ctx.callback_loop->trigger(ctx.sync); ctx.callback_time->trigger(ctx.sync); @@ -841,7 +845,9 @@ StateId SimulationMachine::StepSimulators::impl(SimulationContext& ctx) { throw cloe::ModelStop("simulator {} no longer operational", simulator.name()); } if (sim_time != ctx.sync.time()) { - throw cloe::ModelError("simulator {} did not progress to required time: got {}ms, expected {}ms", simulator.name(), sim_time.count()/1'000'000, ctx.sync.time().count()/1'000'000); + throw cloe::ModelError( + "simulator {} did not progress to required time: got {}ms, expected {}ms", + simulator.name(), sim_time.count() / 1'000'000, ctx.sync.time().count() / 1'000'000); } } catch (cloe::ModelReset& e) { throw; @@ -1167,7 +1173,7 @@ StateId SimulationMachine::KeepAlive::impl(SimulationContext& ctx) { // ABORT --------------------------------------------------------------------------------------- // StateId SimulationMachine::Abort::impl(SimulationContext& ctx) { - auto previous_state = state_machine()->previous_state(); + const auto* previous_state = state_machine()->previous_state(); if (previous_state == KEEP_ALIVE) { return DISCONNECT; } else if (previous_state == CONNECT) { @@ -1215,10 +1221,11 @@ SimulationResult Simulation::run() { // Abort handler: SimulationMachine machine; - abort_fn_ = [this, &ctx, &machine]() { + abort_fn_ = [this, &r, &ctx, &machine]() { static size_t requests = 0; logger()->info("Signal caught."); + r.errors.emplace_back("user sent abort signal (e.g. with Ctrl+C)"); requests += 1; if (ctx.progress.is_init_ended()) { if (!ctx.progress.is_exec_ended()) { @@ -1275,9 +1282,11 @@ SimulationResult Simulation::run() { // Run the simulation machine.run(ctx); } catch (cloe::ConcludedError& e) { - // Nothing + r.errors.emplace_back(e.what()); + ctx.outcome = SimulationOutcome::Aborted; } catch (std::exception& e) { - throw; + r.errors.emplace_back(e.what()); + ctx.outcome = SimulationOutcome::Aborted; } try { @@ -1286,6 +1295,7 @@ SimulationResult Simulation::run() { ctx.commander->run_all(config_.engine.hooks_post_disconnect); } catch (cloe::ConcludedError& e) { // TODO(ben): ensure outcome is correctly saved + r.errors.emplace_back(e.what()); } // Wait for any running children to terminate. @@ -1314,7 +1324,7 @@ size_t Simulation::write_output(const SimulationResult& r) const { } size_t files_written = 0; - auto write_file = [&](auto filename, cloe::Json output) { + auto write_file = [&](auto filename, const cloe::Json& output) { if (!filename) { return; } diff --git a/engine/src/simulation.hpp b/engine/src/simulation.hpp index 23fe0c8cf..ac3de4edd 100644 --- a/engine/src/simulation.hpp +++ b/engine/src/simulation.hpp @@ -41,6 +41,7 @@ struct SimulationResult { SimulationSync sync; cloe::Duration elapsed; SimulationOutcome outcome; + std::vector errors; SimulationStatistics statistics; cloe::Json triggers; boost::optional output_dir; @@ -93,8 +94,12 @@ struct SimulationResult { friend void to_json(cloe::Json& j, const SimulationResult& r) { j = cloe::Json{ - {"uuid", r.uuid}, {"statistics", r.statistics}, {"simulation", r.sync}, - {"elapsed", r.elapsed}, {"outcome", r.outcome}, + {"elapsed", r.elapsed}, + {"errors", r.errors}, + {"outcome", r.outcome}, + {"simulation", r.sync}, + {"statistics", r.statistics}, + {"uuid", r.uuid}, }; } }; @@ -145,8 +150,8 @@ class Simulation { void signal_abort(); private: - cloe::Logger logger_; cloe::Stack config_; + cloe::Logger logger_; std::string uuid_; std::function abort_fn_; diff --git a/engine/src/simulation_context.cpp b/engine/src/simulation_context.cpp index 05abf0a68..106d7a4aa 100644 --- a/engine/src/simulation_context.cpp +++ b/engine/src/simulation_context.cpp @@ -15,12 +15,13 @@ * * SPDX-License-Identifier: Apache-2.0 */ -/** - * \file simulation_context.cpp - */ #include "simulation_context.hpp" +#include +#include +#include + namespace engine { std::string SimulationContext::version() const { return CLOE_ENGINE_VERSION; } diff --git a/engine/src/simulation_context.hpp b/engine/src/simulation_context.hpp index d3e81b125..d64bb2268 100644 --- a/engine/src/simulation_context.hpp +++ b/engine/src/simulation_context.hpp @@ -22,32 +22,27 @@ #pragma once -#include // for uint64_t -#include // for function<> -#include // for map<> -#include // for unique_ptr<>, shared_ptr<> -#include // for string -#include // for vector<> - -#include // for optional<> - -#include // for Controller -#include // for Duration -#include // for Registrar -#include // for Simulator +#include // for uint64_t +#include // for function<> +#include // for map<> +#include // for unique_ptr<>, shared_ptr<> +#include // for optional<> +#include // for string +#include // for vector<> + +#include // for Simulator, Controller, Registrar, Vehicle, Duration #include // for Sync #include // for DEFINE_NIL_EVENT #include // for Accumulator #include // for DurationTimer -#include // for Vehicle -#include "coordinator.hpp" // for Coordinator -#include "registrar.hpp" // for Registrar -#include "server.hpp" // for Server -#include "stack.hpp" // for Stack -#include "simulation_progress.hpp" // for SimulationProgress -#include "utility/command.hpp" // for CommandExecuter -#include "utility/time_event.hpp" // for TimeCallback +#include "coordinator.hpp" // for Coordinator +#include "registrar.hpp" // for Registrar +#include "server.hpp" // for Server +#include "simulation_progress.hpp" // for SimulationProgress +#include "stack.hpp" // for Stack +#include "utility/command.hpp" // for CommandExecuter +#include "utility/time_event.hpp" // for TimeCallback namespace engine { @@ -194,7 +189,9 @@ DEFINE_NIL_EVENT(Loop, "loop", "begin of inner simulation loop each cycle") /** * SimulationContext represents the entire context of a running simulation. * - * This clearly separates data from functionality. + * This clearly separates data from functionality. There is no constructor + * where extra initialization is performed. Instead any initialization is + * performed in the simulation states in the `simulation.cpp` file. */ struct SimulationContext { // Setup @@ -216,7 +213,7 @@ struct SimulationContext { std::map> simulators; std::map> vehicles; std::map> controllers; - boost::optional outcome; + std::optional outcome; timer::DurationTimer cycle_duration; bool pause_execution{false}; @@ -233,6 +230,7 @@ struct SimulationContext { public: std::string version() const; + cloe::Logger logger() const { return cloe::logger::get("cloe"); } std::shared_ptr simulation_registrar(); diff --git a/engine/src/stack.cpp b/engine/src/stack.cpp index ee977fe59..086b98ec6 100644 --- a/engine/src/stack.cpp +++ b/engine/src/stack.cpp @@ -136,8 +136,7 @@ inline auto include_prototype() { return IncludeSchema(nullptr, "").file_exists( Conf default_conf_reader(const std::string& filepath) { return Conf{filepath}; } Stack::Stack() - : Confable() - , reserved_ids_({"_", "cloe", "sim", "simulation"}) + : reserved_ids_({"_", "cloe", "sim", "simulation"}) , version(CLOE_STACK_VERSION) , engine_schema(&engine, "engine configuration") , include_schema(&include, include_prototype(), "include configurations") @@ -179,10 +178,16 @@ Stack::Stack(const Stack& other) Stack::reset_schema(); } -Stack::Stack(Stack&& other) : Stack() { swap(*this, other); } +Stack::Stack(Stack&& other) noexcept : Stack() { swap(*this, other); } -Stack& Stack::operator=(Stack other) { +Stack& Stack::operator=(const Stack& other) { // Make use of the copy constructor and then swap. + auto copy = Stack(other); + swap(*this, copy); + return *this; +} + +Stack& Stack::operator=(Stack&& other) noexcept { swap(*this, other); return *this; } @@ -235,6 +240,18 @@ void Stack::reset_schema() { // clang-format on } +void Stack::merge_stackfile(const std::string& filepath) { + this->logger()->info("Include conf: {}", filepath); + Conf config; + try { + config = this->conf_reader_func_(filepath); + } catch (std::exception& e) { + this->logger()->error("Error including conf {}: {}", filepath, e.what()); + throw; + } + from_conf(config); +} + void Stack::apply_plugin_conf(const PluginConf& c) { // 1. Check existence if (!fs::exists(c.plugin_path)) { @@ -303,7 +320,7 @@ void Stack::insert_plugin(std::shared_ptr p, const PluginConf& c) { } // Skip loading if already loaded - if (all_plugins_.count(canon)) { + if (all_plugins_.count(canon) != 0) { logger()->debug("Skip {}", canon); return; } @@ -648,7 +665,7 @@ bool Stack::is_valid() const { void Stack::check_consistency() const { std::map ns; - for (auto x : reserved_ids_) { + for (const auto& x : reserved_ids_) { ns[x] = "reserved keyword"; } @@ -656,7 +673,7 @@ void Stack::check_consistency() const { * Check that the given name does not exist yet. */ auto check = [&](const char* type, const std::string& key) { - if (ns.count(key)) { + if (ns.count(key) != 0) { throw Error("cannot define a new {} with the name '{}': a {} with that name already exists", type, key, ns[key]); } @@ -664,7 +681,7 @@ void Stack::check_consistency() const { }; auto check_has = [&](const char* type, const std::string& key) { - if (!ns.count(key)) { + if (ns.count(key) == 0) { throw Error("cannot find a {} with the name '{}': no entity with that name has been defined", type, key); } else if (ns[key] != type) { @@ -700,7 +717,7 @@ void Stack::check_consistency() const { } void Stack::check_defaults() const { - auto check = [&](auto f, const std::string& name, const std::vector defaults) { + auto check = [&](auto f, const std::string& name, const std::vector& defaults) { auto y = f->clone(); for (const auto& c : defaults) { if (c.name.value_or(name) == name && c.binding.value_or(f->name()) == f->name()) { diff --git a/engine/src/stack.hpp b/engine/src/stack.hpp index 6f6b63629..dc2bdded7 100644 --- a/engine/src/stack.hpp +++ b/engine/src/stack.hpp @@ -44,31 +44,9 @@ #include // for Source #include // for Command +#include "config.hpp" #include "plugin.hpp" // for Plugin -#ifndef CLOE_STACK_VERSION -#define CLOE_STACK_VERSION "4.1" -#endif - -#ifndef CLOE_STACK_SUPPORTED_VERSIONS -#define CLOE_STACK_SUPPORTED_VERSIONS \ - { "4", "4.0", "4.1" } -#endif - -#ifndef CLOE_XDG_SUFFIX -#define CLOE_XDG_SUFFIX "cloe" -#endif - -#ifndef CLOE_CONFIG_HOME -#define CLOE_CONFIG_HOME "${XDG_CONFIG_HOME-${HOME}/.config}/" CLOE_XDG_SUFFIX -#endif - -#ifndef CLOE_DATA_HOME -#define CLOE_DATA_HOME "${XDG_DATA_HOME-${HOME}/.local/share}/" CLOE_XDG_SUFFIX -#endif - -#define CLOE_SIMULATION_UUID_VAR "CLOE_SIMULATION_UUID" - namespace cloe { /** @@ -81,7 +59,7 @@ class PersistentConfable : public Confable { public: const Conf& conf() const { return conf_; } - void from_conf(const Conf& c) { + void from_conf(const Conf& c) override { Confable::from_conf(c); conf_ = c; } @@ -889,8 +867,8 @@ class StackIncompleteError : public Error { public: explicit StackIncompleteError(std::vector&& missing); - std::string all_sections_missing(const std::string& sep = ", ") const; - const std::vector& sections_missing() const { return sections_missing_; } + [[nodiscard]] std::string all_sections_missing(const std::string& sep = ", ") const; + [[nodiscard]] const std::vector& sections_missing() const { return sections_missing_; } private: std::vector sections_missing_; @@ -898,6 +876,10 @@ class StackIncompleteError : public Error { using ConfReader = std::function; +/** + * Stack represents the entire configuration of the engine and the simulation + * to be run. + */ class Stack : public Confable { private: // Constants (1) std::vector reserved_ids_; @@ -937,9 +919,10 @@ class Stack : public Confable { public: // Constructors Stack(); Stack(const Stack& other); - Stack(Stack&& other); - Stack& operator=(Stack other); - ~Stack() = default; + Stack(Stack&& other) noexcept; + Stack& operator=(const Stack& other); + Stack& operator=(Stack&& other) noexcept; + ~Stack() override = default; friend void swap(Stack& left, Stack& right); @@ -956,9 +939,14 @@ class Stack : public Confable { */ void set_conf_reader(ConfReader fn) { assert(fn != nullptr); - conf_reader_func_ = fn; + conf_reader_func_ = std::move(fn); } + /** + * Open the given JSON file and merge it into the stack. + */ + void merge_stackfile(const std::string& filepath); + /** * Try to load and register one or more plugins based on the PluginConf. */ diff --git a/engine/src/stack_component_test.cpp b/engine/src/stack_component_test.cpp new file mode 100644 index 000000000..0889fb165 --- /dev/null +++ b/engine/src/stack_component_test.cpp @@ -0,0 +1,167 @@ +/* + * Copyright 2020 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file stack_component_test.cpp + * \see stack.hpp + * \see stack.cpp + */ + +#include +#include +#include + +#include // for DEFINE_COMPONENT_FACTORY +#include // for EgoSensor +#include // for ObjectSensor +#include // for Json +#include // for assert_from_conf +#include "stack.hpp" // for Stack +using namespace cloe; // NOLINT(build/namespaces) + +namespace { + +struct DummySensorConf : public Confable { + uint64_t freq; + + CONFABLE_SCHEMA(DummySensorConf) { + return Schema{ + {"freq", Schema(&freq, "some frequency")}, + }; + } +}; + +class DummySensor : public NopObjectSensor { + public: + DummySensor(const std::string& name, const DummySensorConf& conf, + std::shared_ptr obs) + : NopObjectSensor(), config_(conf), sensor_(obs) {} + + virtual ~DummySensor() noexcept = default; + + uint64_t get_freq() const { return config_.freq; } + + private: + DummySensorConf config_; + std::shared_ptr sensor_; +}; + +DEFINE_COMPONENT_FACTORY(DummySensorFactory, DummySensorConf, "dummy_object_sensor", + "test component config") + +DEFINE_COMPONENT_FACTORY_MAKE(DummySensorFactory, DummySensor, cloe::ObjectSensor) + +} + +TEST(cloe_stack, deserialization_of_component) { + // Create a sensor component from the given configuration. + std::shared_ptr cf = std::make_shared(); + ComponentConf cc = ComponentConf("dummy_sensor", cf); + fable::assert_from_conf(cc, R"({ + "binding": "dummy_sensor", + "name": "my_dummy_sensor", + "from": "some_obj_sensor", + "args" : { + "freq" : 9 + } + })"); + + // In production code, "some_obj_sensor" would be fetched from a list of all + // available sensors. Skip this step here. + std::vector> from = {std::shared_ptr()}; + auto d = std::dynamic_pointer_cast( + std::shared_ptr(std::move(cf->make(cc.args, from)))); + ASSERT_EQ(d->get_freq(), 9); +} + +namespace { + +class FusionSensor : public NopObjectSensor { + public: + FusionSensor(const std::string& name, const DummySensorConf& conf, + std::vector> obj_sensors, + std::shared_ptr ego_sensor) + : NopObjectSensor(), config_(conf), obj_sensors_(obj_sensors), ego_sensor_(ego_sensor) {} + + virtual ~FusionSensor() noexcept = default; + + uint64_t get_freq() const { return config_.freq; } + + private: + DummySensorConf config_; + std::vector> obj_sensors_; + std::shared_ptr ego_sensor_; +}; + +DEFINE_COMPONENT_FACTORY(FusionSensorFactory, DummySensorConf, "fusion_object_sensor", + "test component config") + +std::unique_ptr<::cloe::Component> FusionSensorFactory::make( + const ::cloe::Conf& c, std::vector> comp_src) const { + decltype(config_) conf{config_}; + if (!c->is_null()) { + conf.from_conf(c); + } + std::vector> obj_sensors; + std::vector> ego_sensors; + for (auto& comp : comp_src) { + auto obj_s = std::dynamic_pointer_cast(comp); + if (obj_s != nullptr) { + obj_sensors.push_back(obj_s); + continue; + } + auto ego_s = std::dynamic_pointer_cast(comp); + if (ego_s != nullptr) { + ego_sensors.push_back(ego_s); + continue; + } + throw Error("FusionSensorFactory: Source component type not supported: from {}", comp->name()); + } + if (ego_sensors.size() != 1) { + throw Error("FusionSensorFactory: {}: Require exactly one ego sensor.", this->name()); + } + return std::make_unique(this->name(), conf, obj_sensors, ego_sensors.front()); +} + +} + +TEST(cloe_stack, deserialization_of_fusion_component) { + // Create a sensor component from the given configuration. + std::shared_ptr cf = std::make_shared(); + ComponentConf cc = ComponentConf("fusion_object_sensor", cf); + fable::assert_from_conf(cc, R"({ + "binding": "fusion_object_sensor", + "name": "my_fusion_sensor", + "from": [ + "ego_sensor0", + "obj_sensor1", + "obj_sensor2" + ], + "args" : { + "freq" : 77 + } + })"); + + // In production code, a component list containing "ego_sensor0", ... would + // be generated. Skip this step here. + std::vector> sensor_subset = { + std::make_shared(), std::make_shared(), + std::make_shared()}; + auto f = std::dynamic_pointer_cast( + std::shared_ptr(std::move(cf->make(cc.args, sensor_subset)))); + ASSERT_EQ(f->get_freq(), 77); +} diff --git a/engine/src/main_stack.cpp b/engine/src/stack_factory.cpp similarity index 74% rename from engine/src/main_stack.cpp rename to engine/src/stack_factory.cpp index 1d0670d87..6d5af1c39 100644 --- a/engine/src/main_stack.cpp +++ b/engine/src/stack_factory.cpp @@ -16,30 +16,29 @@ * SPDX-License-Identifier: Apache-2.0 */ /** - * \file main_stack.cpp - * \see main_stack.hpp + * \file stack_factory.cpp + * \see stack_factory.hpp */ -#include "main_stack.hpp" +#include "stack_factory.hpp" -#include // for ostream, cerr -#include // for string -#include // for vector<> +#include // for path +#include // for ostream, cerr +#include // for string +#include // for vector<> -#include // for path -#include // for optional<> +#include // for path #include // for split_string #include // for merge_config #include // for Environment #include // for pretty_print, read_conf_from_file +#include "config.hpp" // for CLOE_PLUGIN_PATH #include "plugins/nop_controller.hpp" // for NopFactory #include "plugins/nop_simulator.hpp" // for NopFactory #include "stack.hpp" // for Stack -#define CLOE_PLUGIN_PATH "CLOE_PLUGIN_PATH" - namespace cloe { Conf read_conf(const StackOptions& opt, const std::string& filepath) { @@ -72,30 +71,36 @@ void merge_stack(const StackOptions& opt, Stack& s, const std::string& filepath) s.validate_self(); }; - if (!opt.error) { + if (opt.error == nullptr) { + merge(); + return; + } + + try { merge(); - } else { - try { - merge(); - } catch (SchemaError& e) { - fable::pretty_print(e, *opt.error); - throw ConcludedError{e}; - } catch (ConfError& e) { - fable::pretty_print(e, *opt.error); - throw ConcludedError{e}; - } catch (Error& e) { - *opt.error << filepath << ": " << e.what() << std::endl; - if (e.has_explanation()) { - *opt.error << " Note:\n" << fable::indent_string(e.explanation(), " ") << std::endl; - } - throw ConcludedError{e}; - } catch (std::exception& e) { - *opt.error << filepath << ": " << e.what() << std::endl; - throw ConcludedError{e}; + } catch (SchemaError& e) { + fable::pretty_print(e, *opt.error); + throw ConcludedError{e}; + } catch (ConfError& e) { + fable::pretty_print(e, *opt.error); + throw ConcludedError{e}; + } catch (Error& e) { + *opt.error << filepath << ": " << e.what() << std::endl; + if (e.has_explanation()) { + *opt.error << " Note:\n" << fable::indent_string(e.explanation(), " ") << std::endl; } + throw ConcludedError{e}; + } catch (std::exception& e) { + *opt.error << filepath << ": " << e.what() << std::endl; + throw ConcludedError{e}; } } +template +inline bool contains(const std::vector& v, const T& x) { + return std::find(v.begin(), v.end(), x) != v.end(); +} + Stack new_stack(const StackOptions& opt) { Stack s; @@ -123,16 +128,23 @@ Stack new_stack(const StackOptions& opt) { // Setup plugin path: if (!opt.no_system_plugins) { + // FIXME(windows): These paths are linux-specific. s.engine.plugin_path = { "/usr/local/lib/cloe", "/usr/lib/cloe", }; } - std::string paths = opt.environment.get()->get_or(CLOE_PLUGIN_PATH, ""); - for (auto&& p : utility::split_string(std::move(paths), ":")) { + std::string plugin_paths = opt.environment.get()->get_or(CLOE_PLUGIN_PATH, ""); + for (auto&& p : utility::split_string(std::move(plugin_paths), ":")) { + if (contains(s.engine.plugin_path, p)) { + continue; + } s.engine.plugin_path.emplace_back(std::move(p)); } for (const auto& p : opt.plugin_paths) { + if (contains(s.engine.plugin_path, p)) { + continue; + } s.engine.plugin_path.emplace_back(p); } @@ -157,7 +169,6 @@ Stack new_stack(const StackOptions& opt, const std::string& filepath) { if (!filepath.empty()) { merge_stack(opt, s, filepath); } - return s; } diff --git a/engine/src/main_stack.hpp b/engine/src/stack_factory.hpp similarity index 61% rename from engine/src/main_stack.hpp rename to engine/src/stack_factory.hpp index bcb3c0e32..f95e352a5 100644 --- a/engine/src/main_stack.hpp +++ b/engine/src/stack_factory.hpp @@ -16,8 +16,16 @@ * SPDX-License-Identifier: Apache-2.0 */ /** - * \file main_stack.hpp - * \see main_stack.cpp + * \file stack_factory.hpp + * \see stack_factory.cpp + * + * This file provides methods for creating `Stack` objects. + * + * These can be configured through the environment and the CLI by way of + * `StackOptions`. Only one `Stack` object is created in an execution, + * all further stackfiles are merged into the first `Stack` object. + * While this is the current behavior, it is not guaranteed; + * `Stack` is not a singleton. */ #pragma once @@ -27,16 +35,21 @@ #include // for string #include // for vector<> -#include // for optional<> - #include // for Environment -#include "stack.hpp" // for Stack namespace cloe { -// See main.cpp for descriptions of flags. +class Stack; + +/** + * StackOptions contains the configuration required to create new `Stack` objects. + * + * These are provided via the command line and the environment. + * + * \see main.cpp for description of flags + */ struct StackOptions { - boost::optional error = std::cerr; + std::ostream* error = &std::cerr; std::shared_ptr environment; // Flags: @@ -52,12 +65,24 @@ struct StackOptions { bool secure_mode = false; }; +/** + * Create a new empty default Stack from `StackOptions`. + */ Stack new_stack(const StackOptions& opt); +/** + * Create a new Stack from the stackfile provided, respecting `StackOptions`. + */ Stack new_stack(const StackOptions& opt, const std::string& filepath); +/** + * Create a new Stack by merging all stackfiles provided, respecting `StackOptions`. + */ Stack new_stack(const StackOptions& opt, const std::vector& filepaths); +/** + * Merge the provided stackfile into the existing `Stack`, respecting `StackOptions`. + */ void merge_stack(const StackOptions& opt, Stack& s, const std::string& filepath); } // namespace cloe diff --git a/engine/src/stack_test.cpp b/engine/src/stack_test.cpp index c89f93ae4..cee1d60ec 100644 --- a/engine/src/stack_test.cpp +++ b/engine/src/stack_test.cpp @@ -25,9 +25,6 @@ #include #include -#include // for DEFINE_COMPONENT_FACTORY -#include // for EgoSensor -#include // for ObjectSensor #include // for Json #include // for assert_from_conf #include "stack.hpp" // for Stack @@ -214,128 +211,3 @@ TEST(cloe_stack, serialization_with_logging) { Conf c{expect}; ASSERT_TRUE(c.has_pointer("/engine/registry_path")); } - -struct DummySensorConf : public Confable { - uint64_t freq; - - CONFABLE_SCHEMA(DummySensorConf) { - return Schema{ - {"freq", Schema(&freq, "some frequency")}, - }; - } -}; - -class DummySensor : public NopObjectSensor { - public: - DummySensor(const std::string& name, const DummySensorConf& conf, - std::shared_ptr obs) - : NopObjectSensor(), config_(conf), sensor_(obs) {} - - virtual ~DummySensor() noexcept = default; - - uint64_t get_freq() const { return config_.freq; } - - private: - DummySensorConf config_; - std::shared_ptr sensor_; -}; - -DEFINE_COMPONENT_FACTORY(DummySensorFactory, DummySensorConf, "dummy_object_sensor", - "test component config") - -DEFINE_COMPONENT_FACTORY_MAKE(DummySensorFactory, DummySensor, cloe::ObjectSensor) - -TEST(cloe_stack, deserialization_of_component) { - // Create a sensor component from the given configuration. - std::shared_ptr cf = std::make_shared(); - ComponentConf cc = ComponentConf("dummy_sensor", cf); - fable::assert_from_conf(cc, R"({ - "binding": "dummy_sensor", - "name": "my_dummy_sensor", - "from": "some_obj_sensor", - "args" : { - "freq" : 9 - } - })"); - - // In production code, "some_obj_sensor" would be fetched from a list of all - // available sensors. Skip this step here. - std::vector> from = {std::shared_ptr()}; - auto d = std::dynamic_pointer_cast( - std::shared_ptr(std::move(cf->make(cc.args, from)))); - ASSERT_EQ(d->get_freq(), 9); -} - -class FusionSensor : public NopObjectSensor { - public: - FusionSensor(const std::string& name, const DummySensorConf& conf, - std::vector> obj_sensors, - std::shared_ptr ego_sensor) - : NopObjectSensor(), config_(conf), obj_sensors_(obj_sensors), ego_sensor_(ego_sensor) {} - - virtual ~FusionSensor() noexcept = default; - - uint64_t get_freq() const { return config_.freq; } - - private: - DummySensorConf config_; - std::vector> obj_sensors_; - std::shared_ptr ego_sensor_; -}; - -DEFINE_COMPONENT_FACTORY(FusionSensorFactory, DummySensorConf, "fusion_object_sensor", - "test component config") - -std::unique_ptr<::cloe::Component> FusionSensorFactory::make( - const ::cloe::Conf& c, std::vector> comp_src) const { - decltype(config_) conf{config_}; - if (!c->is_null()) { - conf.from_conf(c); - } - std::vector> obj_sensors; - std::vector> ego_sensors; - for (auto& comp : comp_src) { - auto obj_s = std::dynamic_pointer_cast(comp); - if (obj_s != nullptr) { - obj_sensors.push_back(obj_s); - continue; - } - auto ego_s = std::dynamic_pointer_cast(comp); - if (ego_s != nullptr) { - ego_sensors.push_back(ego_s); - continue; - } - throw Error("FusionSensorFactory: Source component type not supported: from {}", comp->name()); - } - if (ego_sensors.size() != 1) { - throw Error("FusionSensorFactory: {}: Require exactly one ego sensor.", this->name()); - } - return std::make_unique(this->name(), conf, obj_sensors, ego_sensors.front()); -} - -TEST(cloe_stack, deserialization_of_fusion_component) { - // Create a sensor component from the given configuration. - std::shared_ptr cf = std::make_shared(); - ComponentConf cc = ComponentConf("fusion_object_sensor", cf); - fable::assert_from_conf(cc, R"({ - "binding": "fusion_object_sensor", - "name": "my_fusion_sensor", - "from": [ - "ego_sensor0", - "obj_sensor1", - "obj_sensor2" - ], - "args" : { - "freq" : 77 - } - })"); - - // In production code, a component list containing "ego_sensor0", ... would - // be generated. Skip this step here. - std::vector> sensor_subset = { - std::make_shared(), std::make_shared(), - std::make_shared()}; - auto f = std::dynamic_pointer_cast( - std::shared_ptr(std::move(cf->make(cc.args, sensor_subset)))); - ASSERT_EQ(f->get_freq(), 77); -} From 2bed94abda8ca4a41954123eeca4b223ef588e1a Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 23:58:09 +0200 Subject: [PATCH 06/22] basic: Replace boost::optional with std::optional --- plugins/basic/src/basic.cpp | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/plugins/basic/src/basic.cpp b/plugins/basic/src/basic.cpp index 010ac4397..36ca03fac 100644 --- a/plugins/basic/src/basic.cpp +++ b/plugins/basic/src/basic.cpp @@ -22,16 +22,15 @@ #include "basic.hpp" -#include // for duration<> -#include // for shared_ptr<>, unique_ptr<> -#include // for string -#include // for tie -#include // for pair, make_pair -#include // for vector<> +#include // for duration<> +#include // for shared_ptr<>, unique_ptr<> +#include // for optional +#include // for string +#include // for tie +#include // for pair, make_pair +#include // for vector<> -#include // for optional<> -#include // for Schema -#include // for Optional<> +#include // for Schema #include // for DriverRequest #include // for LatLongActuator @@ -102,10 +101,10 @@ struct AdaptiveCruiseControl { AccConfiguration config; std::shared_ptr vehicle{nullptr}; - bool enabled{false}; // whether the function can be activated - bool active{false}; // whether the function is currently active - size_t distance_algorithm{0}; // index of target distance algorithm - boost::optional target_speed{}; // target speed in [km/h] + bool enabled{false}; // whether the function can be activated + bool active{false}; // whether the function is currently active + size_t distance_algorithm{0}; // index of target distance algorithm + std::optional target_speed{}; // target speed in [km/h] public: explicit AdaptiveCruiseControl(const AccConfiguration& c) : config(c) {} From 12b60b52283632bb38b74657280ccbe3bebc7a82 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Wed, 24 May 2023 14:24:22 +0200 Subject: [PATCH 07/22] engine: Vendor linenoise library Linenoise-Source: https://github.com/antirez/linenoise Linenoise-Commit: 93b2db9bd4968f76148dd62cdadf050ed50b84b3 Linenoise-Date: 2023-03-27 --- NOTICE.md | 6 + engine/conanfile.py | 1 + engine/vendor/linenoise/.gitignore | 3 + engine/vendor/linenoise/CMakeLists.txt | 36 + engine/vendor/linenoise/LICENSE | 25 + engine/vendor/linenoise/Makefile | 7 + engine/vendor/linenoise/README.markdown | 347 ++++++ engine/vendor/linenoise/example.c | 124 +++ engine/vendor/linenoise/linenoise.c | 1348 +++++++++++++++++++++++ engine/vendor/linenoise/linenoise.h | 113 ++ 10 files changed, 2010 insertions(+) create mode 100644 engine/vendor/linenoise/.gitignore create mode 100644 engine/vendor/linenoise/CMakeLists.txt create mode 100644 engine/vendor/linenoise/LICENSE create mode 100644 engine/vendor/linenoise/Makefile create mode 100644 engine/vendor/linenoise/README.markdown create mode 100644 engine/vendor/linenoise/example.c create mode 100644 engine/vendor/linenoise/linenoise.c create mode 100644 engine/vendor/linenoise/linenoise.h diff --git a/NOTICE.md b/NOTICE.md index ae44c1fd3..7f9d67da1 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -46,6 +46,12 @@ The following third-party libraries are included in the Cloe repository: - Website: https://jothepro.github.io/doxygen-awesome-css - Source: docs/_vendor/doxygen-awesome +- Linenoise + - License: BSD2 + - License-Source: https://raw.githubusercontent.com/antirez/linenoise/master/LICENSE + - Website: https://github.com/antirez/linenoise + - Source: engine/vendor/linenoise + The following third-party libraries are used by this project (these are usually installed with the help of Conan): diff --git a/engine/conanfile.py b/engine/conanfile.py index 6593e1569..9580d82c6 100644 --- a/engine/conanfile.py +++ b/engine/conanfile.py @@ -37,6 +37,7 @@ class CloeEngine(ConanFile): exports_sources = [ "src/*", "webui/*", + "vendor/*", "CMakeLists.txt", ] diff --git a/engine/vendor/linenoise/.gitignore b/engine/vendor/linenoise/.gitignore new file mode 100644 index 000000000..7ab7825f5 --- /dev/null +++ b/engine/vendor/linenoise/.gitignore @@ -0,0 +1,3 @@ +linenoise_example +*.dSYM +history.txt diff --git a/engine/vendor/linenoise/CMakeLists.txt b/engine/vendor/linenoise/CMakeLists.txt new file mode 100644 index 000000000..2d5afdebe --- /dev/null +++ b/engine/vendor/linenoise/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.15 FATAL_ERROR) + +project(linenoise LANGUAGES C) + +include(GNUInstallDirs) + +add_library(linenoise + linenoise.c + linenoise.h +) +add_library(linenoise::linenoise ALIAS linenoise) +target_include_directories(linenoise + PUBLIC + "$" + "$" +) +install(TARGETS linenoise + EXPORT linenoiseTargets + LIBRARY + ARCHIVE + RUNTIME +) +install(FILES linenoise.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + +include(CTest) +if(BUILD_TESTING) + add_executable(linenoise-example + example.c + ) + target_link_libraries(linenoise-example + PRIVATE + linenoise + ) +endif() diff --git a/engine/vendor/linenoise/LICENSE b/engine/vendor/linenoise/LICENSE new file mode 100644 index 000000000..18e814865 --- /dev/null +++ b/engine/vendor/linenoise/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2010-2014, Salvatore Sanfilippo +Copyright (c) 2010-2013, Pieter Noordhuis + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/engine/vendor/linenoise/Makefile b/engine/vendor/linenoise/Makefile new file mode 100644 index 000000000..a28541067 --- /dev/null +++ b/engine/vendor/linenoise/Makefile @@ -0,0 +1,7 @@ +linenoise_example: linenoise.h linenoise.c + +linenoise_example: linenoise.c example.c + $(CC) -Wall -W -Os -g -o linenoise_example linenoise.c example.c + +clean: + rm -f linenoise_example diff --git a/engine/vendor/linenoise/README.markdown b/engine/vendor/linenoise/README.markdown new file mode 100644 index 000000000..71313f021 --- /dev/null +++ b/engine/vendor/linenoise/README.markdown @@ -0,0 +1,347 @@ +# Linenoise + +A minimal, zero-config, BSD licensed, readline replacement used in Redis, +MongoDB, Android and many other projects. + +* Single and multi line editing mode with the usual key bindings implemented. +* History handling. +* Completion. +* Hints (suggestions at the right of the prompt as you type). +* Multiplexing mode, with prompt hiding/restoring for asynchronous output. +* About ~850 lines (comments and spaces excluded) of BSD license source code. +* Only uses a subset of VT100 escapes (ANSI.SYS compatible). + +## Can a line editing library be 20k lines of code? + +Line editing with some support for history is a really important feature for command line utilities. Instead of retyping almost the same stuff again and again it's just much better to hit the up arrow and edit on syntax errors, or in order to try a slightly different command. But apparently code dealing with terminals is some sort of Black Magic: readline is 30k lines of code, libedit 20k. Is it reasonable to link small utilities to huge libraries just to get a minimal support for line editing? + +So what usually happens is either: + + * Large programs with configure scripts disabling line editing if readline is not present in the system, or not supporting it at all since readline is GPL licensed and libedit (the BSD clone) is not as known and available as readline is (real world example of this problem: Tclsh). + * Smaller programs not using a configure script not supporting line editing at all (A problem we had with `redis-cli`, for instance). + +The result is a pollution of binaries without line editing support. + +So I spent more or less two hours doing a reality check resulting in this little library: is it *really* needed for a line editing library to be 20k lines of code? Apparently not, it is possibe to get a very small, zero configuration, trivial to embed library, that solves the problem. Smaller programs will just include this, supporting line editing out of the box. Larger programs may use this little library or just checking with configure if readline/libedit is available and resorting to Linenoise if not. + +## Terminals, in 2010. + +Apparently almost every terminal you can happen to use today has some kind of support for basic VT100 escape sequences. So I tried to write a lib using just very basic VT100 features. The resulting library appears to work everywhere I tried to use it, and now can work even on ANSI.SYS compatible terminals, since no +VT220 specific sequences are used anymore. + +The library is currently about 850 lines of code. In order to use it in your project just look at the *example.c* file in the source distribution, it is pretty straightforward. The library supports both a blocking mode and a multiplexing mode, see the API documentation later in this file for more information. + +Linenoise is BSD-licensed code, so you can use both in free software and commercial software. + +## Tested with... + + * Linux text only console ($TERM = linux) + * Linux KDE terminal application ($TERM = xterm) + * Linux xterm ($TERM = xterm) + * Linux Buildroot ($TERM = vt100) + * Mac OS X iTerm ($TERM = xterm) + * Mac OS X default Terminal.app ($TERM = xterm) + * OpenBSD 4.5 through an OSX Terminal.app ($TERM = screen) + * IBM AIX 6.1 + * FreeBSD xterm ($TERM = xterm) + * ANSI.SYS + * Emacs comint mode ($TERM = dumb) + +Please test it everywhere you can and report back! + +## Let's push this forward! + +Patches should be provided in the respect of Linenoise sensibility for small +easy to understand code. + +Send feedbacks to antirez at gmail + +# The API + +Linenoise is very easy to use, and reading the example shipped with the +library should get you up to speed ASAP. Here is a list of API calls +and how to use them. Let's start with the simple blocking mode: + + char *linenoise(const char *prompt); + +This is the main Linenoise call: it shows the user a prompt with line editing +and history capabilities. The prompt you specify is used as a prompt, that is, +it will be printed to the left of the cursor. The library returns a buffer +with the line composed by the user, or NULL on end of file or when there +is an out of memory condition. + +When a tty is detected (the user is actually typing into a terminal session) +the maximum editable line length is `LINENOISE_MAX_LINE`. When instead the +standard input is not a tty, which happens every time you redirect a file +to a program, or use it in an Unix pipeline, there are no limits to the +length of the line that can be returned. + +The returned line should be freed with the `free()` standard system call. +However sometimes it could happen that your program uses a different dynamic +allocation library, so you may also used `linenoiseFree` to make sure the +line is freed with the same allocator it was created. + +The canonical loop used by a program using Linenoise will be something like +this: + + while((line = linenoise("hello> ")) != NULL) { + printf("You wrote: %s\n", line); + linenoiseFree(line); /* Or just free(line) if you use libc malloc. */ + } + +## Single line VS multi line editing + +By default, Linenoise uses single line editing, that is, a single row on the +screen will be used, and as the user types more, the text will scroll towards +left to make room. This works if your program is one where the user is +unlikely to write a lot of text, otherwise multi line editing, where multiple +screens rows are used, can be a lot more comfortable. + +In order to enable multi line editing use the following API call: + + linenoiseSetMultiLine(1); + +You can disable it using `0` as argument. + +## History + +Linenoise supporst history, so that the user does not have to retype +again and again the same things, but can use the down and up arrows in order +to search and re-edit already inserted lines of text. + +The followings are the history API calls: + + int linenoiseHistoryAdd(const char *line); + int linenoiseHistorySetMaxLen(int len); + int linenoiseHistorySave(const char *filename); + int linenoiseHistoryLoad(const char *filename); + +Use `linenoiseHistoryAdd` every time you want to add a new element +to the top of the history (it will be the first the user will see when +using the up arrow). + +Note that for history to work, you have to set a length for the history +(which is zero by default, so history will be disabled if you don't set +a proper one). This is accomplished using the `linenoiseHistorySetMaxLen` +function. + +Linenoise has direct support for persisting the history into an history +file. The functions `linenoiseHistorySave` and `linenoiseHistoryLoad` do +just that. Both functions return -1 on error and 0 on success. + +## Mask mode + +Sometimes it is useful to allow the user to type passwords or other +secrets that should not be displayed. For such situations linenoise supports +a "mask mode" that will just replace the characters the user is typing +with `*` characters, like in the following example: + + $ ./linenoise_example + hello> get mykey + echo: 'get mykey' + hello> /mask + hello> ********* + +You can enable and disable mask mode using the following two functions: + + void linenoiseMaskModeEnable(void); + void linenoiseMaskModeDisable(void); + +## Completion + +Linenoise supports completion, which is the ability to complete the user +input when she or he presses the `` key. + +In order to use completion, you need to register a completion callback, which +is called every time the user presses ``. Your callback will return a +list of items that are completions for the current string. + +The following is an example of registering a completion callback: + + linenoiseSetCompletionCallback(completion); + +The completion must be a function returning `void` and getting as input +a `const char` pointer, which is the line the user has typed so far, and +a `linenoiseCompletions` object pointer, which is used as argument of +`linenoiseAddCompletion` in order to add completions inside the callback. +An example will make it more clear: + + void completion(const char *buf, linenoiseCompletions *lc) { + if (buf[0] == 'h') { + linenoiseAddCompletion(lc,"hello"); + linenoiseAddCompletion(lc,"hello there"); + } + } + +Basically in your completion callback, you inspect the input, and return +a list of items that are good completions by using `linenoiseAddCompletion`. + +If you want to test the completion feature, compile the example program +with `make`, run it, type `h` and press ``. + +## Hints + +Linenoise has a feature called *hints* which is very useful when you +use Linenoise in order to implement a REPL (Read Eval Print Loop) for +a program that accepts commands and arguments, but may also be useful in +other conditions. + +The feature shows, on the right of the cursor, as the user types, hints that +may be useful. The hints can be displayed using a different color compared +to the color the user is typing, and can also be bold. + +For example as the user starts to type `"git remote add"`, with hints it's +possible to show on the right of the prompt a string ` `. + +The feature works similarly to the history feature, using a callback. +To register the callback we use: + + linenoiseSetHintsCallback(hints); + +The callback itself is implemented like this: + + char *hints(const char *buf, int *color, int *bold) { + if (!strcasecmp(buf,"git remote add")) { + *color = 35; + *bold = 0; + return " "; + } + return NULL; + } + +The callback function returns the string that should be displayed or NULL +if no hint is available for the text the user currently typed. The returned +string will be trimmed as needed depending on the number of columns available +on the screen. + +It is possible to return a string allocated in dynamic way, by also registering +a function to deallocate the hint string once used: + + void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); + +The free hint callback will just receive the pointer and free the string +as needed (depending on how the hits callback allocated it). + +As you can see in the example above, a `color` (in xterm color terminal codes) +can be provided together with a `bold` attribute. If no color is set, the +current terminal foreground color is used. If no bold attribute is set, +non-bold text is printed. + +Color codes are: + + red = 31 + green = 32 + yellow = 33 + blue = 34 + magenta = 35 + cyan = 36 + white = 37; + +## Screen handling + +Sometimes you may want to clear the screen as a result of something the +user typed. You can do this by calling the following function: + + void linenoiseClearScreen(void); + +## Asyncrhronous API + +Sometimes you want to read from the keyboard but also from sockets or other +external events, and at the same time there could be input to display to the +user *while* the user is typing something. Let's call this the "IRC problem", +since if you want to write an IRC client with linenoise, without using +some fully featured libcurses approach, you will surely end having such an +issue. + +Fortunately now a multiplexing friendly API exists, and it is just what the +blocking calls internally use. To start, we need to initialize a linenoise +context like this: + + struct linenoiseState ls; + char buf[1024]; + linenoiseEditStart(&ls,-1,-1,buf,sizeof(buf),"some prompt> "); + +The two -1 and -1 arguments are the stdin/out descriptors. If they are +set to -1, linenoise will just use the default stdin/out file descriptors. +Now as soon as we have data from stdin (and we know it via select(2) or +some other way), we can ask linenoise to read the next character with: + + linenoiseEditFeed(&ls); + +The function returns a `char` pointer: if the user didn't yet press enter +to provide a line to the program, it will return `linenoiseEditMore`, that +means we need to call `linenoiseEditFeed()` again when more data is +available. If the function returns non NULL, then this is a heap allocated +data (to be freed with `linenoiseFree()`) representing the user input. +When the function returns NULL, than the user pressed CTRL-C or CTRL-D +with an empty line, to quit the program, or there was some I/O error. + +After each line is received (or if you want to quit the program, and exit raw mode), the following function needs to be called: + + linenoiseEditStop(&ls); + +To start reading the next line, a new linenoiseEditStart() must +be called, in order to reset the state, and so forth, so a typical event +handler called when the standard input is readable, will work similarly +to the example below: + +``` c +void stdinHasSomeData(void) { + char *line = linenoiseEditFeed(&LineNoiseState); + if (line == linenoiseEditMore) return; + linenoiseEditStop(&LineNoiseState); + if (line == NULL) exit(0); + + printf("line: %s\n", line); + linenoiseFree(line); + linenoiseEditStart(&LineNoiseState,-1,-1,LineNoiseBuffer,sizeof(LineNoiseBuffer),"serial> "); +} +``` + +Now that we have a way to avoid blocking in the user input, we can use +two calls to hide/show the edited line, so that it is possible to also +show some input that we received (from socekts, bluetooth, whatever) on +screen: + + linenoiseHide(&ls); + printf("some data...\n"); + linenoiseShow(&ls); + +To the API calls, the linenoise example C file implements a multiplexing +example using select(2) and the asynchronous API: + +```c + struct linenoiseState ls; + char buf[1024]; + linenoiseEditStart(&ls,-1,-1,buf,sizeof(buf),"hello> "); + + while(1) { + // Select(2) setup code removed... + retval = select(ls.ifd+1, &readfds, NULL, NULL, &tv); + if (retval == -1) { + perror("select()"); + exit(1); + } else if (retval) { + line = linenoiseEditFeed(&ls); + /* A NULL return means: line editing is continuing. + * Otherwise the user hit enter or stopped editing + * (CTRL+C/D). */ + if (line != linenoiseEditMore) break; + } else { + // Timeout occurred + static int counter = 0; + linenoiseHide(&ls); + printf("Async output %d.\n", counter++); + linenoiseShow(&ls); + } + } + linenoiseEditStop(&ls); + if (line == NULL) exit(0); /* Ctrl+D/C. */ +``` + +You can test the example by running the example program with the `--async` option. + +## Related projects + +* [Linenoise NG](https://github.com/arangodb/linenoise-ng) is a fork of Linenoise that aims to add more advanced features like UTF-8 support, Windows support and other features. Uses C++ instead of C as development language. +* [Linenoise-swift](https://github.com/andybest/linenoise-swift) is a reimplementation of Linenoise written in Swift. diff --git a/engine/vendor/linenoise/example.c b/engine/vendor/linenoise/example.c new file mode 100644 index 000000000..3a7f8b372 --- /dev/null +++ b/engine/vendor/linenoise/example.c @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include "linenoise.h" + +void completion(const char *buf, linenoiseCompletions *lc) { + if (buf[0] == 'h') { + linenoiseAddCompletion(lc,"hello"); + linenoiseAddCompletion(lc,"hello there"); + } +} + +char *hints(const char *buf, int *color, int *bold) { + if (!strcasecmp(buf,"hello")) { + *color = 35; + *bold = 0; + return " World"; + } + return NULL; +} + +int main(int argc, char **argv) { + char *line; + char *prgname = argv[0]; + int async = 0; + + /* Parse options, with --multiline we enable multi line editing. */ + while(argc > 1) { + argc--; + argv++; + if (!strcmp(*argv,"--multiline")) { + linenoiseSetMultiLine(1); + printf("Multi-line mode enabled.\n"); + } else if (!strcmp(*argv,"--keycodes")) { + linenoisePrintKeyCodes(); + exit(0); + } else if (!strcmp(*argv,"--async")) { + async = 1; + } else { + fprintf(stderr, "Usage: %s [--multiline] [--keycodes] [--async]\n", prgname); + exit(1); + } + } + + /* Set the completion callback. This will be called every time the + * user uses the key. */ + linenoiseSetCompletionCallback(completion); + linenoiseSetHintsCallback(hints); + + /* Load history from file. The history file is just a plain text file + * where entries are separated by newlines. */ + linenoiseHistoryLoad("history.txt"); /* Load the history at startup */ + + /* Now this is the main loop of the typical linenoise-based application. + * The call to linenoise() will block as long as the user types something + * and presses enter. + * + * The typed string is returned as a malloc() allocated string by + * linenoise, so the user needs to free() it. */ + + while(1) { + if (!async) { + line = linenoise("hello> "); + if (line == NULL) break; + } else { + /* Asynchronous mode using the multiplexing API: wait for + * data on stdin, and simulate async data coming from some source + * using the select(2) timeout. */ + struct linenoiseState ls; + char buf[1024]; + linenoiseEditStart(&ls,-1,-1,buf,sizeof(buf),"hello> "); + while(1) { + fd_set readfds; + struct timeval tv; + int retval; + + FD_ZERO(&readfds); + FD_SET(ls.ifd, &readfds); + tv.tv_sec = 1; // 1 sec timeout + tv.tv_usec = 0; + + retval = select(ls.ifd+1, &readfds, NULL, NULL, &tv); + if (retval == -1) { + perror("select()"); + exit(1); + } else if (retval) { + line = linenoiseEditFeed(&ls); + /* A NULL return means: line editing is continuing. + * Otherwise the user hit enter or stopped editing + * (CTRL+C/D). */ + if (line != linenoiseEditMore) break; + } else { + // Timeout occurred + static int counter = 0; + linenoiseHide(&ls); + printf("Async output %d.\n", counter++); + linenoiseShow(&ls); + } + } + linenoiseEditStop(&ls); + if (line == NULL) exit(0); /* Ctrl+D/C. */ + } + + /* Do something with the string. */ + if (line[0] != '\0' && line[0] != '/') { + printf("echo: '%s'\n", line); + linenoiseHistoryAdd(line); /* Add to the history. */ + linenoiseHistorySave("history.txt"); /* Save the history on disk. */ + } else if (!strncmp(line,"/historylen",11)) { + /* The "/historylen" command will change the history len. */ + int len = atoi(line+11); + linenoiseHistorySetMaxLen(len); + } else if (!strncmp(line, "/mask", 5)) { + linenoiseMaskModeEnable(); + } else if (!strncmp(line, "/unmask", 7)) { + linenoiseMaskModeDisable(); + } else if (line[0] == '/') { + printf("Unreconized command: %s\n", line); + } + free(line); + } + return 0; +} diff --git a/engine/vendor/linenoise/linenoise.c b/engine/vendor/linenoise/linenoise.c new file mode 100644 index 000000000..5e8aee577 --- /dev/null +++ b/engine/vendor/linenoise/linenoise.c @@ -0,0 +1,1348 @@ +/* linenoise.c -- guerrilla line editing library against the idea that a + * line editing lib needs to be 20,000 lines of C code. + * + * You can find the latest source code at: + * + * http://github.com/antirez/linenoise + * + * Does a number of crazy assumptions that happen to be true in 99.9999% of + * the 2010 UNIX computers around. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * ------------------------------------------------------------------------ + * + * References: + * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html + * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html + * + * Todo list: + * - Filter bogus Ctrl+ combinations. + * - Win32 support + * + * Bloat: + * - History search like Ctrl+r in readline? + * + * List of escape sequences used by this program, we do everything just + * with three sequences. In order to be so cheap we may have some + * flickering effect with some slow terminal, but the lesser sequences + * the more compatible. + * + * EL (Erase Line) + * Sequence: ESC [ n K + * Effect: if n is 0 or missing, clear from cursor to end of line + * Effect: if n is 1, clear from beginning of line to cursor + * Effect: if n is 2, clear entire line + * + * CUF (CUrsor Forward) + * Sequence: ESC [ n C + * Effect: moves cursor forward n chars + * + * CUB (CUrsor Backward) + * Sequence: ESC [ n D + * Effect: moves cursor backward n chars + * + * The following is used to get the terminal width if getting + * the width with the TIOCGWINSZ ioctl fails + * + * DSR (Device Status Report) + * Sequence: ESC [ 6 n + * Effect: reports the current cusor position as ESC [ n ; m R + * where n is the row and m is the column + * + * When multi line mode is enabled, we also use an additional escape + * sequence. However multi line editing is disabled by default. + * + * CUU (Cursor Up) + * Sequence: ESC [ n A + * Effect: moves cursor up of n chars. + * + * CUD (Cursor Down) + * Sequence: ESC [ n B + * Effect: moves cursor down of n chars. + * + * When linenoiseClearScreen() is called, two additional escape sequences + * are used in order to clear the screen and position the cursor at home + * position. + * + * CUP (Cursor position) + * Sequence: ESC [ H + * Effect: moves the cursor to upper left corner + * + * ED (Erase display) + * Sequence: ESC [ 2 J + * Effect: clear the whole screen + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "linenoise.h" + +#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 +#define LINENOISE_MAX_LINE 4096 +static char *unsupported_term[] = {"dumb","cons25","emacs",NULL}; +static linenoiseCompletionCallback *completionCallback = NULL; +static linenoiseHintsCallback *hintsCallback = NULL; +static linenoiseFreeHintsCallback *freeHintsCallback = NULL; +static char *linenoiseNoTTY(void); +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); +static void refreshLineWithFlags(struct linenoiseState *l, int flags); + +static struct termios orig_termios; /* In order to restore at exit.*/ +static int maskmode = 0; /* Show "***" instead of input. For passwords. */ +static int rawmode = 0; /* For atexit() function to check if restore is needed*/ +static int mlmode = 0; /* Multi line mode. Default is single line. */ +static int atexit_registered = 0; /* Register atexit just 1 time. */ +static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; +static int history_len = 0; +static char **history = NULL; + +enum KEY_ACTION{ + KEY_NULL = 0, /* NULL */ + CTRL_A = 1, /* Ctrl+a */ + CTRL_B = 2, /* Ctrl-b */ + CTRL_C = 3, /* Ctrl-c */ + CTRL_D = 4, /* Ctrl-d */ + CTRL_E = 5, /* Ctrl-e */ + CTRL_F = 6, /* Ctrl-f */ + CTRL_H = 8, /* Ctrl-h */ + TAB = 9, /* Tab */ + CTRL_K = 11, /* Ctrl+k */ + CTRL_L = 12, /* Ctrl+l */ + ENTER = 13, /* Enter */ + CTRL_N = 14, /* Ctrl-n */ + CTRL_P = 16, /* Ctrl-p */ + CTRL_T = 20, /* Ctrl-t */ + CTRL_U = 21, /* Ctrl+u */ + CTRL_W = 23, /* Ctrl+w */ + ESC = 27, /* Escape */ + BACKSPACE = 127 /* Backspace */ +}; + +static void linenoiseAtExit(void); +int linenoiseHistoryAdd(const char *line); +#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen +#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. +#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. +static void refreshLine(struct linenoiseState *l); + +/* Debugging macro. */ +#if 0 +FILE *lndebug_fp = NULL; +#define lndebug(...) \ + do { \ + if (lndebug_fp == NULL) { \ + lndebug_fp = fopen("/tmp/lndebug.txt","a"); \ + fprintf(lndebug_fp, \ + "[%d %d %d] p: %d, rows: %d, rpos: %d, max: %d, oldmax: %d\n", \ + (int)l->len,(int)l->pos,(int)l->oldpos,plen,rows,rpos, \ + (int)l->oldrows,old_rows); \ + } \ + fprintf(lndebug_fp, ", " __VA_ARGS__); \ + fflush(lndebug_fp); \ + } while (0) +#else +#define lndebug(fmt, ...) +#endif + +/* ======================= Low level terminal handling ====================== */ + +/* Enable "mask mode". When it is enabled, instead of the input that + * the user is typing, the terminal will just display a corresponding + * number of asterisks, like "****". This is useful for passwords and other + * secrets that should not be displayed. */ +void linenoiseMaskModeEnable(void) { + maskmode = 1; +} + +/* Disable mask mode. */ +void linenoiseMaskModeDisable(void) { + maskmode = 0; +} + +/* Set if to use or not the multi line mode. */ +void linenoiseSetMultiLine(int ml) { + mlmode = ml; +} + +/* Return true if the terminal name is in the list of terminals we know are + * not able to understand basic escape sequences. */ +static int isUnsupportedTerm(void) { + char *term = getenv("TERM"); + int j; + + if (term == NULL) return 0; + for (j = 0; unsupported_term[j]; j++) + if (!strcasecmp(term,unsupported_term[j])) return 1; + return 0; +} + +/* Raw mode: 1960 magic shit. */ +static int enableRawMode(int fd) { + struct termios raw; + + if (!isatty(STDIN_FILENO)) goto fatal; + if (!atexit_registered) { + atexit(linenoiseAtExit); + atexit_registered = 1; + } + if (tcgetattr(fd,&orig_termios) == -1) goto fatal; + + raw = orig_termios; /* modify the original mode */ + /* input modes: no break, no CR to NL, no parity check, no strip char, + * no start/stop output control. */ + raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + /* output modes - disable post processing */ + raw.c_oflag &= ~(OPOST); + /* control modes - set 8 bit chars */ + raw.c_cflag |= (CS8); + /* local modes - choing off, canonical off, no extended functions, + * no signal chars (^Z,^C) */ + raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); + /* control chars - set return condition: min number of bytes and timer. + * We want read to return every single byte, without timeout. */ + raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ + + /* put terminal in raw mode after flushing */ + if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; + rawmode = 1; + return 0; + +fatal: + errno = ENOTTY; + return -1; +} + +static void disableRawMode(int fd) { + /* Don't even check the return value as it's too late. */ + if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) + rawmode = 0; +} + +/* Use the ESC [6n escape sequence to query the horizontal cursor position + * and return it. On error -1 is returned, on success the position of the + * cursor. */ +static int getCursorPosition(int ifd, int ofd) { + char buf[32]; + int cols, rows; + unsigned int i = 0; + + /* Report cursor location */ + if (write(ofd, "\x1b[6n", 4) != 4) return -1; + + /* Read the response: ESC [ rows ; cols R */ + while (i < sizeof(buf)-1) { + if (read(ifd,buf+i,1) != 1) break; + if (buf[i] == 'R') break; + i++; + } + buf[i] = '\0'; + + /* Parse it. */ + if (buf[0] != ESC || buf[1] != '[') return -1; + if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; + return cols; +} + +/* Try to get the number of columns in the current terminal, or assume 80 + * if it fails. */ +static int getColumns(int ifd, int ofd) { + struct winsize ws; + + if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { + /* ioctl() failed. Try to query the terminal itself. */ + int start, cols; + + /* Get the initial position so we can restore it later. */ + start = getCursorPosition(ifd,ofd); + if (start == -1) goto failed; + + /* Go to right margin and get position. */ + if (write(ofd,"\x1b[999C",6) != 6) goto failed; + cols = getCursorPosition(ifd,ofd); + if (cols == -1) goto failed; + + /* Restore position. */ + if (cols > start) { + char seq[32]; + snprintf(seq,32,"\x1b[%dD",cols-start); + if (write(ofd,seq,strlen(seq)) == -1) { + /* Can't recover... */ + } + } + return cols; + } else { + return ws.ws_col; + } + +failed: + return 80; +} + +/* Clear the screen. Used to handle ctrl+l */ +void linenoiseClearScreen(void) { + if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { + /* nothing to do, just to avoid warning. */ + } +} + +/* Beep, used for completion when there is nothing to complete or when all + * the choices were already shown. */ +static void linenoiseBeep(void) { + fprintf(stderr, "\x7"); + fflush(stderr); +} + +/* ============================== Completion ================================ */ + +/* Free a list of completion option populated by linenoiseAddCompletion(). */ +static void freeCompletions(linenoiseCompletions *lc) { + size_t i; + for (i = 0; i < lc->len; i++) + free(lc->cvec[i]); + if (lc->cvec != NULL) + free(lc->cvec); +} + +/* Called by completeLine() and linenoiseShow() to render the current + * edited line with the proposed completion. If the current completion table + * is already available, it is passed as second argument, otherwise the + * function will use the callback to obtain it. + * + * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { + /* Obtain the table of completions if the caller didn't provide one. */ + linenoiseCompletions ctable = { 0, NULL }; + if (lc == NULL) { + completionCallback(ls->buf,&ctable); + lc = &ctable; + } + + /* Show the edited line with completion if possible, or just refresh. */ + if (ls->completion_idx < lc->len) { + struct linenoiseState saved = *ls; + ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); + ls->buf = lc->cvec[ls->completion_idx]; + refreshLineWithFlags(ls,flags); + ls->len = saved.len; + ls->pos = saved.pos; + ls->buf = saved.buf; + } else { + refreshLineWithFlags(ls,flags); + } + + /* Free the completions table if needed. */ + if (lc != &ctable) freeCompletions(&ctable); +} + +/* This is an helper function for linenoiseEdit*() and is called when the + * user types the key in order to complete the string currently in the + * input. + * + * The state of the editing is encapsulated into the pointed linenoiseState + * structure as described in the structure definition. + * + * If the function returns non-zero, the caller should handle the + * returned value as a byte read from the standard input, and process + * it as usually: this basically means that the function may return a byte + * read from the termianl but not processed. Otherwise, if zero is returned, + * the input was consumed by the completeLine() function to navigate the + * possible completions, and the caller should read for the next characters + * from stdin. */ +static int completeLine(struct linenoiseState *ls, int keypressed) { + linenoiseCompletions lc = { 0, NULL }; + int nwritten; + char c = keypressed; + + completionCallback(ls->buf,&lc); + if (lc.len == 0) { + linenoiseBeep(); + ls->in_completion = 0; + } else { + switch(c) { + case 9: /* tab */ + if (ls->in_completion == 0) { + ls->in_completion = 1; + ls->completion_idx = 0; + } else { + ls->completion_idx = (ls->completion_idx+1) % (lc.len+1); + if (ls->completion_idx == lc.len) linenoiseBeep(); + } + c = 0; + break; + case 27: /* escape */ + /* Re-show original buffer */ + if (ls->completion_idx < lc.len) refreshLine(ls); + ls->in_completion = 0; + c = 0; + break; + default: + /* Update buffer and return */ + if (ls->completion_idx < lc.len) { + nwritten = snprintf(ls->buf,ls->buflen,"%s", + lc.cvec[ls->completion_idx]); + ls->len = ls->pos = nwritten; + } + ls->in_completion = 0; + break; + } + + /* Show completion or original buffer */ + if (ls->in_completion && ls->completion_idx < lc.len) { + refreshLineWithCompletion(ls,&lc,REFRESH_ALL); + } else { + refreshLine(ls); + } + } + + freeCompletions(&lc); + return c; /* Return last read character */ +} + +/* Register a callback function to be called for tab-completion. */ +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { + completionCallback = fn; +} + +/* Register a hits function to be called to show hits to the user at the + * right of the prompt. */ +void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { + hintsCallback = fn; +} + +/* Register a function to free the hints returned by the hints callback + * registered with linenoiseSetHintsCallback(). */ +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { + freeHintsCallback = fn; +} + +/* This function is used by the callback function registered by the user + * in order to add completion options given the input string when the + * user typed . See the example.c source code for a very easy to + * understand example. */ +void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { + size_t len = strlen(str); + char *copy, **cvec; + + copy = malloc(len+1); + if (copy == NULL) return; + memcpy(copy,str,len+1); + cvec = realloc(lc->cvec,sizeof(char*)*(lc->len+1)); + if (cvec == NULL) { + free(copy); + return; + } + lc->cvec = cvec; + lc->cvec[lc->len++] = copy; +} + +/* =========================== Line editing ================================= */ + +/* We define a very simple "append buffer" structure, that is an heap + * allocated string where we can append to. This is useful in order to + * write all the escape sequences in a buffer and flush them to the standard + * output in a single call, to avoid flickering effects. */ +struct abuf { + char *b; + int len; +}; + +static void abInit(struct abuf *ab) { + ab->b = NULL; + ab->len = 0; +} + +static void abAppend(struct abuf *ab, const char *s, int len) { + char *new = realloc(ab->b,ab->len+len); + + if (new == NULL) return; + memcpy(new+ab->len,s,len); + ab->b = new; + ab->len += len; +} + +static void abFree(struct abuf *ab) { + free(ab->b); +} + +/* Helper of refreshSingleLine() and refreshMultiLine() to show hints + * to the right of the prompt. */ +void refreshShowHints(struct abuf *ab, struct linenoiseState *l, int plen) { + char seq[64]; + if (hintsCallback && plen+l->len < l->cols) { + int color = -1, bold = 0; + char *hint = hintsCallback(l->buf,&color,&bold); + if (hint) { + int hintlen = strlen(hint); + int hintmaxlen = l->cols-(plen+l->len); + if (hintlen > hintmaxlen) hintlen = hintmaxlen; + if (bold == 1 && color == -1) color = 37; + if (color != -1 || bold != 0) + snprintf(seq,64,"\033[%d;%d;49m",bold,color); + else + seq[0] = '\0'; + abAppend(ab,seq,strlen(seq)); + abAppend(ab,hint,hintlen); + if (color != -1 || bold != 0) + abAppend(ab,"\033[0m",4); + /* Call the function to free the hint returned. */ + if (freeHintsCallback) freeHintsCallback(hint); + } + } +} + +/* Single line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshSingleLine(struct linenoiseState *l, int flags) { + char seq[64]; + size_t plen = strlen(l->prompt); + int fd = l->ofd; + char *buf = l->buf; + size_t len = l->len; + size_t pos = l->pos; + struct abuf ab; + + while((plen+pos) >= l->cols) { + buf++; + len--; + pos--; + } + while (plen+len > l->cols) { + len--; + } + + abInit(&ab); + /* Cursor to left edge */ + snprintf(seq,sizeof(seq),"\r"); + abAppend(&ab,seq,strlen(seq)); + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + abAppend(&ab,l->prompt,strlen(l->prompt)); + if (maskmode == 1) { + while (len--) abAppend(&ab,"*",1); + } else { + abAppend(&ab,buf,len); + } + /* Show hits if any. */ + refreshShowHints(&ab,l,plen); + } + + /* Erase to right */ + snprintf(seq,sizeof(seq),"\x1b[0K"); + abAppend(&ab,seq,strlen(seq)); + + if (flags & REFRESH_WRITE) { + /* Move cursor to original position. */ + snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen)); + abAppend(&ab,seq,strlen(seq)); + } + + if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ + abFree(&ab); +} + +/* Multi line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshMultiLine(struct linenoiseState *l, int flags) { + char seq[64]; + int plen = strlen(l->prompt); + int rows = (plen+l->len+l->cols-1)/l->cols; /* rows used by current buf. */ + int rpos = (plen+l->oldpos+l->cols)/l->cols; /* cursor relative row. */ + int rpos2; /* rpos after refresh. */ + int col; /* colum position, zero-based. */ + int old_rows = l->oldrows; + int fd = l->ofd, j; + struct abuf ab; + + l->oldrows = rows; + + /* First step: clear all the lines used before. To do so start by + * going to the last row. */ + abInit(&ab); + + if (flags & REFRESH_CLEAN) { + if (old_rows-rpos > 0) { + lndebug("go down %d", old_rows-rpos); + snprintf(seq,64,"\x1b[%dB", old_rows-rpos); + abAppend(&ab,seq,strlen(seq)); + } + + /* Now for every row clear it, go up. */ + for (j = 0; j < old_rows-1; j++) { + lndebug("clear+up"); + snprintf(seq,64,"\r\x1b[0K\x1b[1A"); + abAppend(&ab,seq,strlen(seq)); + } + } + + if (flags & REFRESH_ALL) { + /* Clean the top line. */ + lndebug("clear"); + snprintf(seq,64,"\r\x1b[0K"); + abAppend(&ab,seq,strlen(seq)); + } + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + abAppend(&ab,l->prompt,strlen(l->prompt)); + if (maskmode == 1) { + unsigned int i; + for (i = 0; i < l->len; i++) abAppend(&ab,"*",1); + } else { + abAppend(&ab,l->buf,l->len); + } + + /* Show hits if any. */ + refreshShowHints(&ab,l,plen); + + /* If we are at the very end of the screen with our prompt, we need to + * emit a newline and move the prompt to the first column. */ + if (l->pos && + l->pos == l->len && + (l->pos+plen) % l->cols == 0) + { + lndebug(""); + abAppend(&ab,"\n",1); + snprintf(seq,64,"\r"); + abAppend(&ab,seq,strlen(seq)); + rows++; + if (rows > (int)l->oldrows) l->oldrows = rows; + } + + /* Move cursor to right position. */ + rpos2 = (plen+l->pos+l->cols)/l->cols; /* Current cursor relative row */ + lndebug("rpos2 %d", rpos2); + + /* Go up till we reach the expected positon. */ + if (rows-rpos2 > 0) { + lndebug("go-up %d", rows-rpos2); + snprintf(seq,64,"\x1b[%dA", rows-rpos2); + abAppend(&ab,seq,strlen(seq)); + } + + /* Set column. */ + col = (plen+(int)l->pos) % (int)l->cols; + lndebug("set col %d", 1+col); + if (col) + snprintf(seq,64,"\r\x1b[%dC", col); + else + snprintf(seq,64,"\r"); + abAppend(&ab,seq,strlen(seq)); + } + + lndebug("\n"); + l->oldpos = l->pos; + + if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ + abFree(&ab); +} + +/* Calls the two low level functions refreshSingleLine() or + * refreshMultiLine() according to the selected mode. */ +static void refreshLineWithFlags(struct linenoiseState *l, int flags) { + if (mlmode) + refreshMultiLine(l,flags); + else + refreshSingleLine(l,flags); +} + +/* Utility function to avoid specifying REFRESH_ALL all the times. */ +static void refreshLine(struct linenoiseState *l) { + refreshLineWithFlags(l,REFRESH_ALL); +} + +/* Hide the current line, when using the multiplexing API. */ +void linenoiseHide(struct linenoiseState *l) { + if (mlmode) + refreshMultiLine(l,REFRESH_CLEAN); + else + refreshSingleLine(l,REFRESH_CLEAN); +} + +/* Show the current line, when using the multiplexing API. */ +void linenoiseShow(struct linenoiseState *l) { + if (l->in_completion) { + refreshLineWithCompletion(l,NULL,REFRESH_WRITE); + } else { + refreshLineWithFlags(l,REFRESH_WRITE); + } +} + +/* Insert the character 'c' at cursor current position. + * + * On error writing to the terminal -1 is returned, otherwise 0. */ +int linenoiseEditInsert(struct linenoiseState *l, char c) { + if (l->len < l->buflen) { + if (l->len == l->pos) { + l->buf[l->pos] = c; + l->pos++; + l->len++; + l->buf[l->len] = '\0'; + if ((!mlmode && l->plen+l->len < l->cols && !hintsCallback)) { + /* Avoid a full update of the line in the + * trivial case. */ + char d = (maskmode==1) ? '*' : c; + if (write(l->ofd,&d,1) == -1) return -1; + } else { + refreshLine(l); + } + } else { + memmove(l->buf+l->pos+1,l->buf+l->pos,l->len-l->pos); + l->buf[l->pos] = c; + l->len++; + l->pos++; + l->buf[l->len] = '\0'; + refreshLine(l); + } + } + return 0; +} + +/* Move cursor on the left. */ +void linenoiseEditMoveLeft(struct linenoiseState *l) { + if (l->pos > 0) { + l->pos--; + refreshLine(l); + } +} + +/* Move cursor on the right. */ +void linenoiseEditMoveRight(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos++; + refreshLine(l); + } +} + +/* Move cursor to the start of the line. */ +void linenoiseEditMoveHome(struct linenoiseState *l) { + if (l->pos != 0) { + l->pos = 0; + refreshLine(l); + } +} + +/* Move cursor to the end of the line. */ +void linenoiseEditMoveEnd(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos = l->len; + refreshLine(l); + } +} + +/* Substitute the currently edited line with the next or previous history + * entry as specified by 'dir'. */ +#define LINENOISE_HISTORY_NEXT 0 +#define LINENOISE_HISTORY_PREV 1 +void linenoiseEditHistoryNext(struct linenoiseState *l, int dir) { + if (history_len > 1) { + /* Update the current history entry before to + * overwrite it with the next one. */ + free(history[history_len - 1 - l->history_index]); + history[history_len - 1 - l->history_index] = strdup(l->buf); + /* Show the new entry */ + l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; + if (l->history_index < 0) { + l->history_index = 0; + return; + } else if (l->history_index >= history_len) { + l->history_index = history_len-1; + return; + } + strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); + l->buf[l->buflen-1] = '\0'; + l->len = l->pos = strlen(l->buf); + refreshLine(l); + } +} + +/* Delete the character at the right of the cursor without altering the cursor + * position. Basically this is what happens with the "Delete" keyboard key. */ +void linenoiseEditDelete(struct linenoiseState *l) { + if (l->len > 0 && l->pos < l->len) { + memmove(l->buf+l->pos,l->buf+l->pos+1,l->len-l->pos-1); + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Backspace implementation. */ +void linenoiseEditBackspace(struct linenoiseState *l) { + if (l->pos > 0 && l->len > 0) { + memmove(l->buf+l->pos-1,l->buf+l->pos,l->len-l->pos); + l->pos--; + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Delete the previosu word, maintaining the cursor at the start of the + * current word. */ +void linenoiseEditDeletePrevWord(struct linenoiseState *l) { + size_t old_pos = l->pos; + size_t diff; + + while (l->pos > 0 && l->buf[l->pos-1] == ' ') + l->pos--; + while (l->pos > 0 && l->buf[l->pos-1] != ' ') + l->pos--; + diff = old_pos - l->pos; + memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); + l->len -= diff; + refreshLine(l); +} + +/* This function is part of the multiplexed API of Linenoise, that is used + * in order to implement the blocking variant of the API but can also be + * called by the user directly in an event driven program. It will: + * + * 1. Initialize the linenoise state passed by the user. + * 2. Put the terminal in RAW mode. + * 3. Show the prompt. + * 4. Return control to the user, that will have to call linenoiseEditFeed() + * each time there is some data arriving in the standard input. + * + * The user can also call linenoiseEditHide() and linenoiseEditShow() if it + * is required to show some input arriving asyncronously, without mixing + * it with the currently edited line. + * + * When linenoiseEditFeed() returns non-NULL, the user finished with the + * line editing session (pressed enter CTRL-D/C): in this case the caller + * needs to call linenoiseEditStop() to put back the terminal in normal + * mode. This will not destroy the buffer, as long as the linenoiseState + * is still valid in the context of the caller. + * + * The function returns 0 on success, or -1 if writing to standard output + * fails. If stdin_fd or stdout_fd are set to -1, the default is to use + * STDIN_FILENO and STDOUT_FILENO. + */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { + /* Populate the linenoise state that we pass to functions implementing + * specific editing functionalities. */ + l->in_completion = 0; + l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; + l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; + l->buf = buf; + l->buflen = buflen; + l->prompt = prompt; + l->plen = strlen(prompt); + l->oldpos = l->pos = 0; + l->len = 0; + l->cols = getColumns(stdin_fd, stdout_fd); + l->oldrows = 0; + l->history_index = 0; + + /* Buffer starts empty. */ + l->buf[0] = '\0'; + l->buflen--; /* Make sure there is always space for the nulterm */ + + /* If stdin is not a tty, stop here with the initialization. We + * will actually just read a line from standard input in blocking + * mode later, in linenoiseEditFeed(). */ + if (!isatty(l->ifd)) return 0; + + /* Enter raw mode. */ + if (enableRawMode(l->ifd) == -1) return -1; + + /* The latest history entry is always our current buffer, that + * initially is just an empty string. */ + linenoiseHistoryAdd(""); + + if (write(l->ofd,prompt,l->plen) == -1) return -1; + return 0; +} + +char *linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; + +/* This function is part of the multiplexed API of linenoise, see the top + * comment on linenoiseEditStart() for more information. Call this function + * each time there is some data to read from the standard input file + * descriptor. In the case of blocking operations, this function can just be + * called in a loop, and block. + * + * The function returns linenoiseEditMore to signal that line editing is still + * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise + * the function returns the pointer to the heap-allocated buffer with the + * edited line, that the user should free with linenoiseFree(). + * + * On special conditions, NULL is returned and errno is populated: + * + * EAGAIN if the user pressed Ctrl-C + * ENOENT if the user pressed Ctrl-D + * + * Some other errno: I/O error. + */ +char *linenoiseEditFeed(struct linenoiseState *l) { + /* Not a TTY, pass control to line reading without character + * count limits. */ + if (!isatty(l->ifd)) return linenoiseNoTTY(); + + char c; + int nread; + char seq[3]; + + nread = read(l->ifd,&c,1); + if (nread <= 0) return NULL; + + /* Only autocomplete when the callback is set. It returns < 0 when + * there was an error reading from fd. Otherwise it will return the + * character that should be handled next. */ + if ((l->in_completion || c == 9) && completionCallback != NULL) { + c = completeLine(l,c); + /* Return on errors */ + if (c < 0) return NULL; + /* Read next character when 0 */ + if (c == 0) return linenoiseEditMore; + } + + switch(c) { + case ENTER: /* enter */ + history_len--; + free(history[history_len]); + if (mlmode) linenoiseEditMoveEnd(l); + if (hintsCallback) { + /* Force a refresh without hints to leave the previous + * line as the user typed it after a newline. */ + linenoiseHintsCallback *hc = hintsCallback; + hintsCallback = NULL; + refreshLine(l); + hintsCallback = hc; + } + return strdup(l->buf); + case CTRL_C: /* ctrl-c */ + errno = EAGAIN; + return NULL; + case BACKSPACE: /* backspace */ + case 8: /* ctrl-h */ + linenoiseEditBackspace(l); + break; + case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the + line is empty, act as end-of-file. */ + if (l->len > 0) { + linenoiseEditDelete(l); + } else { + history_len--; + free(history[history_len]); + errno = ENOENT; + return NULL; + } + break; + case CTRL_T: /* ctrl-t, swaps current character with previous. */ + if (l->pos > 0 && l->pos < l->len) { + int aux = l->buf[l->pos-1]; + l->buf[l->pos-1] = l->buf[l->pos]; + l->buf[l->pos] = aux; + if (l->pos != l->len-1) l->pos++; + refreshLine(l); + } + break; + case CTRL_B: /* ctrl-b */ + linenoiseEditMoveLeft(l); + break; + case CTRL_F: /* ctrl-f */ + linenoiseEditMoveRight(l); + break; + case CTRL_P: /* ctrl-p */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case CTRL_N: /* ctrl-n */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case ESC: /* escape sequence */ + /* Read the next two bytes representing the escape sequence. + * Use two calls to handle slow terminals returning the two + * chars at different times. */ + if (read(l->ifd,seq,1) == -1) break; + if (read(l->ifd,seq+1,1) == -1) break; + + /* ESC [ sequences. */ + if (seq[0] == '[') { + if (seq[1] >= '0' && seq[1] <= '9') { + /* Extended escape, read additional byte. */ + if (read(l->ifd,seq+2,1) == -1) break; + if (seq[2] == '~') { + switch(seq[1]) { + case '3': /* Delete key. */ + linenoiseEditDelete(l); + break; + } + } + } else { + switch(seq[1]) { + case 'A': /* Up */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case 'B': /* Down */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case 'C': /* Right */ + linenoiseEditMoveRight(l); + break; + case 'D': /* Left */ + linenoiseEditMoveLeft(l); + break; + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + } + + /* ESC O sequences. */ + else if (seq[0] == 'O') { + switch(seq[1]) { + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + break; + default: + if (linenoiseEditInsert(l,c)) return NULL; + break; + case CTRL_U: /* Ctrl+u, delete the whole line. */ + l->buf[0] = '\0'; + l->pos = l->len = 0; + refreshLine(l); + break; + case CTRL_K: /* Ctrl+k, delete from current to end of line. */ + l->buf[l->pos] = '\0'; + l->len = l->pos; + refreshLine(l); + break; + case CTRL_A: /* Ctrl+a, go to the start of the line */ + linenoiseEditMoveHome(l); + break; + case CTRL_E: /* ctrl+e, go to the end of the line */ + linenoiseEditMoveEnd(l); + break; + case CTRL_L: /* ctrl+l, clear screen */ + linenoiseClearScreen(); + refreshLine(l); + break; + case CTRL_W: /* ctrl+w, delete previous word */ + linenoiseEditDeletePrevWord(l); + break; + } + return linenoiseEditMore; +} + +/* This is part of the multiplexed linenoise API. See linenoiseEditStart() + * for more information. This function is called when linenoiseEditFeed() + * returns something different than NULL. At this point the user input + * is in the buffer, and we can restore the terminal in normal mode. */ +void linenoiseEditStop(struct linenoiseState *l) { + if (!isatty(l->ifd)) return; + disableRawMode(l->ifd); + printf("\n"); +} + +/* This just implements a blocking loop for the multiplexed API. + * In many applications that are not event-drivern, we can just call + * the blocking linenoise API, wait for the user to complete the editing + * and return the buffer. */ +static char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) +{ + struct linenoiseState l; + + /* Editing without a buffer is invalid. */ + if (buflen == 0) { + errno = EINVAL; + return NULL; + } + + linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); + char *res; + while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); + linenoiseEditStop(&l); + return res; +} + +/* This special mode is used by linenoise in order to print scan codes + * on screen for debugging / development purposes. It is implemented + * by the linenoise_example program using the --keycodes option. */ +void linenoisePrintKeyCodes(void) { + char quit[4]; + + printf("Linenoise key codes debugging mode.\n" + "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); + if (enableRawMode(STDIN_FILENO) == -1) return; + memset(quit,' ',4); + while(1) { + char c; + int nread; + + nread = read(STDIN_FILENO,&c,1); + if (nread <= 0) continue; + memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ + quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ + if (memcmp(quit,"quit",sizeof(quit)) == 0) break; + + printf("'%c' %02x (%d) (type quit to exit)\n", + isprint(c) ? c : '?', (int)c, (int)c); + printf("\r"); /* Go left edge manually, we are in raw mode. */ + fflush(stdout); + } + disableRawMode(STDIN_FILENO); +} + +/* This function is called when linenoise() is called with the standard + * input file descriptor not attached to a TTY. So for example when the + * program using linenoise is called in pipe or with a file redirected + * to its standard input. In this case, we want to be able to return the + * line regardless of its length (by default we are limited to 4k). */ +static char *linenoiseNoTTY(void) { + char *line = NULL; + size_t len = 0, maxlen = 0; + + while(1) { + if (len == maxlen) { + if (maxlen == 0) maxlen = 16; + maxlen *= 2; + char *oldval = line; + line = realloc(line,maxlen); + if (line == NULL) { + if (oldval) free(oldval); + return NULL; + } + } + int c = fgetc(stdin); + if (c == EOF || c == '\n') { + if (c == EOF && len == 0) { + free(line); + return NULL; + } else { + line[len] = '\0'; + return line; + } + } else { + line[len] = c; + len++; + } + } +} + +/* The high level function that is the main API of the linenoise library. + * This function checks if the terminal has basic capabilities, just checking + * for a blacklist of stupid terminals, and later either calls the line + * editing function or uses dummy fgets() so that you will be able to type + * something even in the most desperate of the conditions. */ +char *linenoise(const char *prompt) { + char buf[LINENOISE_MAX_LINE]; + + if (!isatty(STDIN_FILENO)) { + /* Not a tty: read from file / pipe. In this mode we don't want any + * limit to the line size, so we call a function to handle that. */ + return linenoiseNoTTY(); + } else if (isUnsupportedTerm()) { + size_t len; + + printf("%s",prompt); + fflush(stdout); + if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; + len = strlen(buf); + while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { + len--; + buf[len] = '\0'; + } + return strdup(buf); + } else { + char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); + return retval; + } +} + +/* This is just a wrapper the user may want to call in order to make sure + * the linenoise returned buffer is freed with the same allocator it was + * created with. Useful when the main program is using an alternative + * allocator. */ +void linenoiseFree(void *ptr) { + if (ptr == linenoiseEditMore) return; // Protect from API misuse. + free(ptr); +} + +/* ================================ History ================================= */ + +/* Free the history, but does not reset it. Only used when we have to + * exit() to avoid memory leaks are reported by valgrind & co. */ +static void freeHistory(void) { + if (history) { + int j; + + for (j = 0; j < history_len; j++) + free(history[j]); + free(history); + } +} + +/* At exit we'll try to fix the terminal to the initial conditions. */ +static void linenoiseAtExit(void) { + disableRawMode(STDIN_FILENO); + freeHistory(); +} + +/* This is the API call to add a new entry in the linenoise history. + * It uses a fixed array of char pointers that are shifted (memmoved) + * when the history max length is reached in order to remove the older + * entry and make room for the new one, so it is not exactly suitable for huge + * histories, but will work well for a few hundred of entries. + * + * Using a circular buffer is smarter, but a bit more complex to handle. */ +int linenoiseHistoryAdd(const char *line) { + char *linecopy; + + if (history_max_len == 0) return 0; + + /* Initialization on first call. */ + if (history == NULL) { + history = malloc(sizeof(char*)*history_max_len); + if (history == NULL) return 0; + memset(history,0,(sizeof(char*)*history_max_len)); + } + + /* Don't add duplicated lines. */ + if (history_len && !strcmp(history[history_len-1], line)) return 0; + + /* Add an heap allocated copy of the line in the history. + * If we reached the max length, remove the older line. */ + linecopy = strdup(line); + if (!linecopy) return 0; + if (history_len == history_max_len) { + free(history[0]); + memmove(history,history+1,sizeof(char*)*(history_max_len-1)); + history_len--; + } + history[history_len] = linecopy; + history_len++; + return 1; +} + +/* Set the maximum length for the history. This function can be called even + * if there is already some history, the function will make sure to retain + * just the latest 'len' elements if the new history length value is smaller + * than the amount of items already inside the history. */ +int linenoiseHistorySetMaxLen(int len) { + char **new; + + if (len < 1) return 0; + if (history) { + int tocopy = history_len; + + new = malloc(sizeof(char*)*len); + if (new == NULL) return 0; + + /* If we can't copy everything, free the elements we'll not use. */ + if (len < tocopy) { + int j; + + for (j = 0; j < tocopy-len; j++) free(history[j]); + tocopy = len; + } + memset(new,0,sizeof(char*)*len); + memcpy(new,history+(history_len-tocopy), sizeof(char*)*tocopy); + free(history); + history = new; + } + history_max_len = len; + if (history_len > history_max_len) + history_len = history_max_len; + return 1; +} + +/* Save the history in the specified file. On success 0 is returned + * otherwise -1 is returned. */ +int linenoiseHistorySave(const char *filename) { + mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); + FILE *fp; + int j; + + fp = fopen(filename,"w"); + umask(old_umask); + if (fp == NULL) return -1; + chmod(filename,S_IRUSR|S_IWUSR); + for (j = 0; j < history_len; j++) + fprintf(fp,"%s\n",history[j]); + fclose(fp); + return 0; +} + +/* Load the history from the specified file. If the file does not exist + * zero is returned and no operation is performed. + * + * If the file exists and the operation succeeded 0 is returned, otherwise + * on error -1 is returned. */ +int linenoiseHistoryLoad(const char *filename) { + FILE *fp = fopen(filename,"r"); + char buf[LINENOISE_MAX_LINE]; + + if (fp == NULL) return -1; + + while (fgets(buf,LINENOISE_MAX_LINE,fp) != NULL) { + char *p; + + p = strchr(buf,'\r'); + if (!p) p = strchr(buf,'\n'); + if (p) *p = '\0'; + linenoiseHistoryAdd(buf); + } + fclose(fp); + return 0; +} diff --git a/engine/vendor/linenoise/linenoise.h b/engine/vendor/linenoise/linenoise.h new file mode 100644 index 000000000..3f0270e3e --- /dev/null +++ b/engine/vendor/linenoise/linenoise.h @@ -0,0 +1,113 @@ +/* linenoise.h -- VERSION 1.0 + * + * Guerrilla line editing library against the idea that a line editing lib + * needs to be 20,000 lines of C code. + * + * See linenoise.c for more information. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __LINENOISE_H +#define __LINENOISE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include /* For size_t. */ + +extern char *linenoiseEditMore; + +/* The linenoiseState structure represents the state during line editing. + * We pass this state to functions implementing specific editing + * functionalities. */ +struct linenoiseState { + int in_completion; /* The user pressed TAB and we are now in completion + * mode, so input is handled by completeLine(). */ + size_t completion_idx; /* Index of next completion to propose. */ + int ifd; /* Terminal stdin file descriptor. */ + int ofd; /* Terminal stdout file descriptor. */ + char *buf; /* Edited line buffer. */ + size_t buflen; /* Edited line buffer size. */ + const char *prompt; /* Prompt to display. */ + size_t plen; /* Prompt length. */ + size_t pos; /* Current cursor position. */ + size_t oldpos; /* Previous refresh cursor position. */ + size_t len; /* Current edited line length. */ + size_t cols; /* Number of columns in terminal. */ + size_t oldrows; /* Rows used by last refrehsed line (multiline mode) */ + int history_index; /* The history index we are currently editing. */ +}; + +typedef struct linenoiseCompletions { + size_t len; + char **cvec; +} linenoiseCompletions; + +/* Non blocking API. */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt); +char *linenoiseEditFeed(struct linenoiseState *l); +void linenoiseEditStop(struct linenoiseState *l); +void linenoiseHide(struct linenoiseState *l); +void linenoiseShow(struct linenoiseState *l); + +/* Blocking API. */ +char *linenoise(const char *prompt); +void linenoiseFree(void *ptr); + +/* Completion API. */ +typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); +typedef char*(linenoiseHintsCallback)(const char *, int *color, int *bold); +typedef void(linenoiseFreeHintsCallback)(void *); +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); +void linenoiseSetHintsCallback(linenoiseHintsCallback *); +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); +void linenoiseAddCompletion(linenoiseCompletions *, const char *); + +/* History API. */ +int linenoiseHistoryAdd(const char *line); +int linenoiseHistorySetMaxLen(int len); +int linenoiseHistorySave(const char *filename); +int linenoiseHistoryLoad(const char *filename); + +/* Other utilities. */ +void linenoiseClearScreen(void); +void linenoiseSetMultiLine(int ml); +void linenoisePrintKeyCodes(void); +void linenoiseMaskModeEnable(void); +void linenoiseMaskModeDisable(void); + +#ifdef __cplusplus +} +#endif + +#endif /* __LINENOISE_H */ From f06ce826b4b87cbf8f0f1b9ef672baa941c907ff Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 18:15:31 +0200 Subject: [PATCH 08/22] engine: Add Lua scripting support --- .gitattributes | 1 + .luarc.json | 8 + NOTICE.md | 26 +- conanfile.py | 8 +- docs/reference.rst | 4 + docs/reference/lua-initialization.md | 35 + docs/usage.rst | 5 + docs/usage/lua-cloe-shell.md | 122 + docs/usage/lua-editor-integration.md | 51 + docs/usage/lua-introduction.md | 81 + engine/CMakeLists.txt | 43 + engine/conanfile.py | 6 + engine/lua/cloe-engine/fs.lua | 174 ++ engine/lua/cloe-engine/init.lua | 216 ++ engine/lua/cloe-engine/types.lua | 313 +++ engine/lua/cloe/actions.lua | 106 + engine/lua/cloe/engine.lua | 378 +++ engine/lua/cloe/events.lua | 195 ++ engine/lua/cloe/init.lua | 121 + engine/lua/cloe/luax.lua | 763 ++++++ engine/lua/cloe/system.lua | 108 + engine/lua/cloe/testing.lua | 822 +++++++ engine/lua/cloe/typecheck.lua | 51 + engine/lua/inspect.lua | 379 +++ engine/lua/json.lua | 1869 +++++++++++++++ engine/lua/lust.lua | 269 +++ engine/lua/tableshape.lua | 2354 +++++++++++++++++++ engine/lua/typecheck.lua | 1577 +++++++++++++ engine/src/coordinator.cpp | 75 +- engine/src/coordinator.hpp | 19 +- engine/src/lua_action.cpp | 89 + engine/src/lua_action.hpp | 84 + engine/src/lua_api.cpp | 49 + engine/src/lua_api.hpp | 73 + engine/src/lua_setup.cpp | 342 +++ engine/src/lua_setup.hpp | 112 + engine/src/lua_setup_duration.cpp | 47 + engine/src/lua_setup_fs.cpp | 92 + engine/src/lua_setup_stack.cpp | 44 + engine/src/lua_setup_sync.cpp | 39 + engine/src/lua_stack_test.cpp | 91 + engine/src/main.cpp | 24 +- engine/src/main_commands.hpp | 6 + engine/src/main_run.cpp | 11 +- engine/src/main_shell.cpp | 191 ++ engine/src/registrar.hpp | 11 +- engine/src/simulation.cpp | 44 +- engine/src/simulation.hpp | 6 +- engine/src/simulation_context.hpp | 6 + fable/conanfile.py | 2 +- plugins/basic/src/basic.cpp | 35 + runtime/CMakeLists.txt | 4 + runtime/conanfile.py | 1 + runtime/include/cloe/registrar.hpp | 6 + runtime/include/cloe/trigger.hpp | 6 +- tests/project.lua | 256 ++ tests/report_config.lua | 7 + tests/test_engine_json_schema.json | 3 +- tests/test_engine_lua.json | 15 + tests/test_lua.bats | 107 + tests/test_lua01_include_json.lua | 17 + tests/test_lua02_schedule.lua | 31 + tests/test_lua03_schedule_unpin.lua | 27 + tests/test_lua04_schedule_test.lua | 43 + tests/test_lua05_apply_stack.lua | 67 + tests/test_lua06_apply_stack.lua | 67 + tests/test_lua07_schedule_pause.lua | 74 + tests/test_lua08_apply_project.lua | 40 + tests/test_lua09_no_json.lua | 93 + tests/test_lua10_heavy_cpu.lua | 33 + tests/test_lua11_serial_tests.lua | 47 + tests/test_lua12_fail_after_stop.lua | 34 + tests/test_lua13_bdd_eval.lua | 61 + tests/test_lua_api_cloe_system.lua | 18 + tests/test_lua_api_cloe_typecheck.lua | 11 + tests/test_lua_error_coroutine.lua | 15 + tests/test_lua_error_main.lua | 9 + tests/test_lua_error_schedule.lua | 18 + tests/test_lua_error_schedule_test.lua | 17 + tests/test_lua_error_segfault_on_resume.lua | 75 + 80 files changed, 12745 insertions(+), 34 deletions(-) create mode 100644 .gitattributes create mode 100644 .luarc.json create mode 100644 docs/reference/lua-initialization.md create mode 100644 docs/usage/lua-cloe-shell.md create mode 100644 docs/usage/lua-editor-integration.md create mode 100644 docs/usage/lua-introduction.md create mode 100644 engine/lua/cloe-engine/fs.lua create mode 100644 engine/lua/cloe-engine/init.lua create mode 100644 engine/lua/cloe-engine/types.lua create mode 100644 engine/lua/cloe/actions.lua create mode 100644 engine/lua/cloe/engine.lua create mode 100644 engine/lua/cloe/events.lua create mode 100644 engine/lua/cloe/init.lua create mode 100644 engine/lua/cloe/luax.lua create mode 100644 engine/lua/cloe/system.lua create mode 100644 engine/lua/cloe/testing.lua create mode 100644 engine/lua/cloe/typecheck.lua create mode 100644 engine/lua/inspect.lua create mode 100644 engine/lua/json.lua create mode 100644 engine/lua/lust.lua create mode 100644 engine/lua/tableshape.lua create mode 100644 engine/lua/typecheck.lua create mode 100644 engine/src/lua_action.cpp create mode 100644 engine/src/lua_action.hpp create mode 100644 engine/src/lua_api.cpp create mode 100644 engine/src/lua_api.hpp create mode 100644 engine/src/lua_setup.cpp create mode 100644 engine/src/lua_setup.hpp create mode 100644 engine/src/lua_setup_duration.cpp create mode 100644 engine/src/lua_setup_fs.cpp create mode 100644 engine/src/lua_setup_stack.cpp create mode 100644 engine/src/lua_setup_sync.cpp create mode 100644 engine/src/lua_stack_test.cpp create mode 100644 engine/src/main_shell.cpp create mode 100644 tests/project.lua create mode 100644 tests/report_config.lua create mode 100644 tests/test_engine_lua.json create mode 100755 tests/test_lua.bats create mode 100644 tests/test_lua01_include_json.lua create mode 100644 tests/test_lua02_schedule.lua create mode 100644 tests/test_lua03_schedule_unpin.lua create mode 100644 tests/test_lua04_schedule_test.lua create mode 100644 tests/test_lua05_apply_stack.lua create mode 100644 tests/test_lua06_apply_stack.lua create mode 100644 tests/test_lua07_schedule_pause.lua create mode 100644 tests/test_lua08_apply_project.lua create mode 100644 tests/test_lua09_no_json.lua create mode 100644 tests/test_lua10_heavy_cpu.lua create mode 100644 tests/test_lua11_serial_tests.lua create mode 100644 tests/test_lua12_fail_after_stop.lua create mode 100644 tests/test_lua13_bdd_eval.lua create mode 100644 tests/test_lua_api_cloe_system.lua create mode 100644 tests/test_lua_api_cloe_typecheck.lua create mode 100644 tests/test_lua_error_coroutine.lua create mode 100644 tests/test_lua_error_main.lua create mode 100644 tests/test_lua_error_schedule.lua create mode 100644 tests/test_lua_error_schedule_test.lua create mode 100644 tests/test_lua_error_segfault_on_resume.lua diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..176a458f9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto diff --git a/.luarc.json b/.luarc.json new file mode 100644 index 000000000..401547d13 --- /dev/null +++ b/.luarc.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://raw.githubusercontent.com/sumneko/vscode-lua/master/setting/schema.json", + "workspace.library": ["engine/lua"], + "runtime.version": "Lua 5.4", + "completion.displayContext": 1, + "diagnostics.globals": ["cloe"], + "hint.enable": true +} diff --git a/NOTICE.md b/NOTICE.md index 7f9d67da1..21a0272c0 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -46,12 +46,36 @@ The following third-party libraries are included in the Cloe repository: - Website: https://jothepro.github.io/doxygen-awesome-css - Source: docs/_vendor/doxygen-awesome +- Typecheck + - License: MIT + - License-Source: https://github.com/gvvaughan/typecheck/raw/master/LICENSE.md + - Website: https://github.com/gvvaughan/typecheck + - Source: engine/lua/typecheck.lua + +- Tableshape + - License: MIT + - License-Source: https://github.com/leafo/tableshape/blob/v2.6.0/README.md + - Website: https://github.com/leafo/tableshape + - Source: engine/lua/tableshape.lua + - Linenoise - License: BSD2 - License-Source: https://raw.githubusercontent.com/antirez/linenoise/master/LICENSE - Website: https://github.com/antirez/linenoise - Source: engine/vendor/linenoise +- Inspect.lua + - License: MIT + - License-Source: https://raw.githubusercontent.com/kikito/inspect.lua/master/MIT-LICENSE.txt + - Website: https://github.com/kikito/inspect.lua + - Source: engine/lua/inspect.lua + +- Lust + - License: MIT + - License-Source: https://raw.githubusercontent.com/bjornbytes/lust/master/LICENSE + - Website: https://github.com/bjornbytes/lust + - Source: engine/lua/lust.lua + The following third-party libraries are used by this project (these are usually installed with the help of Conan): @@ -99,7 +123,7 @@ installed with the help of Conan): - Conan-Package: inja - Boost - - License: Boost + - License: BSL-1.0 - License-Source: https://www.boost.org/LICENSE_1_0.txt - Website: https://www.boost.org - Conan-Package: boost diff --git a/conanfile.py b/conanfile.py index 25d52ad8a..0bdc1a390 100644 --- a/conanfile.py +++ b/conanfile.py @@ -72,7 +72,9 @@ class Cloe(ConanFile): "fable/examples/*", + "engine/lua/*", "engine/webui/*", + "engine/vendor/*", "CMakelists.txt" ] @@ -93,6 +95,7 @@ def requirements(self): self.requires("incbin/cci.20211107"), self.requires("inja/3.4.0") self.requires("nlohmann_json/3.11.3") + self.requires("sol2/3.3.1") self.requires("spdlog/1.11.0") if self.options.engine_server: self.requires("oatpp/1.3.0", private=True) @@ -105,7 +108,6 @@ def requirements(self): def build_requirements(self): self.test_requires("gtest/1.14.0") - self.test_requires("sol2/3.3.0") def layout(self): cmake.cmake_layout(self) @@ -195,14 +197,18 @@ def package_info(self): self.cpp_info.builddirs.append(os.path.join(self.source_folder, "cmake")) self.cpp_info.includedirs.append(os.path.join(self.build_folder, "include")) bindir = os.path.join(self.build_folder, "bin") + luadir = os.path.join(self.source_folder, "engine/lua") libdir = os.path.join(self.build_folder, "lib"); else: self.cpp_info.builddirs.append(os.path.join("lib", "cmake", "cloe")) bindir = os.path.join(self.package_folder, "bin") + luadir = os.path.join(self.package_folder, "lib/cloe/lua") libdir = None self.output.info(f"Appending PATH environment variable: {bindir}") self.runenv_info.prepend_path("PATH", bindir) + self.output.info(f"Appending CLOE_LUA_PATH environment variable: {luadir}") + self.runenv_info.prepend_path("CLOE_LUA_PATH", luadir) if libdir is not None: self.output.info(f"Appending LD_LIBRARY_PATH environment variable: {libdir}") self.runenv_info.append_path("LD_LIBRARY_PATH", libdir) diff --git a/docs/reference.rst b/docs/reference.rst index 7639754de..814e0364a 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -47,6 +47,9 @@ background knowledge for any of the following topics. :doc:`reference/plugins` provides an overview of all plugins that are part of the Cloe distribution. +:doc:`reference/lua-initialization` + provides an overview of how the engine processes a Lua file. + .. toctree:: :hidden: @@ -62,3 +65,4 @@ background knowledge for any of the following topics. reference/events reference/actions reference/plugins + reference/lua-initialization diff --git a/docs/reference/lua-initialization.md b/docs/reference/lua-initialization.md new file mode 100644 index 000000000..570f0eb27 --- /dev/null +++ b/docs/reference/lua-initialization.md @@ -0,0 +1,35 @@ +Lua Initialization +------------------ + +When a Lua file or script is loaded, the Cloe engine provides a preloaded +`cloe` table with a large API. This API is defined in part through a Lua +runtime, and in part from the C++ engine itself. + +The following operations occur when the engine runs a simulation defined +by a Lua file: `cloe-engine run simulation.lua` + +1. Read options from the command line and environment: + + - Lua package path (`--lua-path`, `CLOE_LUA_PATH`) + - Disable system packages (`--no-system-lua`) + - Cloe plugins (`--plugin-path`, `CLOE_PLUGIN_PATH`) + +2. Initialize Cloe Stack + + - Load plugins found in plugin path + +3. Initialize Lua + + - Set lua package path + - Load built-in Lua base libraries (e.g. `os`, `string`) + - Expose Cloe API via `cloe` Lua table + - Load Cloe Lua runtime (located in the package `lib/cloe/lua` directory) + +4. Source input files + + - Files ending with `.lua` are merged as Lua + - Other files are read as JSON + +5. Start simulation + + - Schedule triggers pending from the Lua script diff --git a/docs/usage.rst b/docs/usage.rst index 69a9e2222..e25661f5c 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -16,6 +16,11 @@ guides will give you a better feel for how Cloe works and how it is built up. usage/viewing-cloe-registry usage/running-cloe-webui usage/creating-a-stackfile + + usage/lua-introduction + usage/lua-cloe-shell + usage/lua-editor-integration + usage/configuring-plugins-in-stackfiles usage/writing-modular-stackfiles usage/user-cloe-configuration diff --git a/docs/usage/lua-cloe-shell.md b/docs/usage/lua-cloe-shell.md new file mode 100644 index 000000000..bfd6482b6 --- /dev/null +++ b/docs/usage/lua-cloe-shell.md @@ -0,0 +1,122 @@ +Cloe-Engine Lua Shell +===================== + +Cloe Engine provides a small Lua shell that you can use as a REPL or a way to +run Lua scripts with access to the Cloe API without running a simulation. + +It currently has the following features: + +- Runs Lua files (passed as arguments) +- Runs Lua strings (passed with `-c` option) +- REPL session (by default or with `-i` flag) +- Session history (press Up/Down in interactive session) +- Multi-line editing (experimental) +- Automatic value printing (experimental) + +You can start the Lua REPL with `cloe-engine shell`. + +Hello World +----------- + +Let's demo the various ways we can print "Hello world!" to the console. + +### In the REPL + +Start the REPL and enter in the statement `print("Hello world!")`: +```console +$ cloe-engine shell +Cloe 0.22.0 Lua interactive shell +Press [Ctrl+D] or [Ctrl+C] to exit. +> print("Hello world!") +Hello world! +> +``` + +### Running a command + +Pass the string from the previous example to the shell with `-c`: +```console +$ cloe-engine shell -c 'print("Hello world!")' +Hello world! +``` +You can pass more than one command with `-c` just by repeating it. + + +### Running a file + +Create a file `hello.lua` with the following contents: +```lua +print("Hello world!") +``` +Now run it with `cloe-engine shell`: +```console +$ cloe-engine shell hello.lua +Hello world! +``` + +Multi-Line Editing +------------------ + +If the statement entered on a line looks complete, the shell will run it. +If there is an error in parsing indicating that the statement looks incomplete, +the shell will prompt you for more input: +``` +> print( +>> "Hello world!" +>> ) +Hello world! +``` +This isn't so important for the above example, but for loops, functions, and +if-statements, it is: +``` +> function a() +>> print( +>> "Hello world!" +>> ) +>> end +> a() +Hello world! +``` + +Whitespace +---------- + +Lua does not care about whitespace very much. This means you can replace +all newlines with spaces and the code works the same. + +Consider the following block of code: +```lua +print("...") +io.write("[") +for _, v in ipairs({1, 2, 3}) do + io.write(v .. ",") +end +io.write("]\n") +print("---") +``` +This can be minified in the following simple ways: + +1. Newlines can be replaced with spaces. +2. Parentheses around plain strings and tables can be removed. +3. Spaces before and after commas, quotes, parentheses, and brackets can be removed. + +This leads to the following minified code: +```lua +print"..."io.write"["for _,v in ipairs{1,2,3}do io.write(v..",")end print"]"print"---" +``` +This means that sending whole blocks of code from the command line or from +another application or from code generation is a lot easier. +``` +$ cloe-engine shell -c 'print"..."io.write"["for _,v in ipairs{1,2,3}do io.write(v..",")end print"]"print"---"' +... +[1,2,3,] +--- +``` +Of course I don't expect you'd really do this kind of crazy minification, but +it demonstrates just how little Lua cares about whitespace. + +:::{note} +This one little quirk can provide significant benefits over the Python +scripting language, because it's very easy to compose generated code without +running into syntax errors because of indentation requirements. +::: diff --git a/docs/usage/lua-editor-integration.md b/docs/usage/lua-editor-integration.md new file mode 100644 index 000000000..e6a1fbe71 --- /dev/null +++ b/docs/usage/lua-editor-integration.md @@ -0,0 +1,51 @@ +Lua Editor Integration +====================== + +In order to have the best user-experience when working with Lua files, it's +important to have a good language server up and running to provide +hinting, auto-completion, and linting. + +The Cloe Engine provides definitions for the +[Sumneko Lua Language Server](https://github.com/LuaLS/vscode-lua), +which can be easily integrated in your favorite editor. +For VS Code, install [this extension](https://marketplace.visualstudio.com/items?itemName=sumneko.lua) + +The language server may need a configuration file in order to find the +definitions (though this should not be necessary for the Cloe repository.) + +Configuration File +------------------ + +Let us assume that you have a directory `tests` containing Lua files that you +want to include Cloe definitions for. + +Place in the `tests` directory or in any directory containing `tests` (such +as the repository root) a file named `.luarc.json` containing the following +content: + +```json +{ + "$schema": "https://raw.githubusercontent.com/sumneko/vscode-lua/master/setting/schema.json", + "workspace.library": ["PATH_CONTAINING_LUA_MODULES"], + "runtime.version": "Lua 5.3", + "completion.displayContext": 1, + "diagnostics.globals": [], + "hint.enable": true +} +``` + +Until we develop a plugin for Sumneko containing all the definitions, you need +to tell Sumneko where to find them by hand, where `PATH_CONTAINING_LUA_MODULES` +is above. + +One approach is to make a symlink to the source Lua files in your own +repository and set the workspace library to the symlink: + +```sh +git clone https://github.com/eclipse/cloe ~/cloe +ln -s ~/cloe/engine/lua meta +sed -e 's/PATH_CONTAINING_LUA_MODULES/meta/' -i .luarc.json +``` + +If you are not committing the `.luarc.json` file, then you can also just +specify the absolute path. diff --git a/docs/usage/lua-introduction.md b/docs/usage/lua-introduction.md new file mode 100644 index 000000000..99c01dfb2 --- /dev/null +++ b/docs/usage/lua-introduction.md @@ -0,0 +1,81 @@ +Introduction to Lua +=================== + +From version 0.22 Cloe will support the use of Lua for configuring and +scripting simulations. This is a major improvement in usability but does +require some getting used to. + +[Lua](https://www.lua.org) is a simple language made for embedding in existing +applications and is very widely used in the industry where user extensibility +and scripting is important. It can be [learned](https://learnxinyminutes.com/docs/lua) +quickly. It is also flexible, which allows us to provide an ergonomic +interface to scripting Cloe. + +Setting up Lua +-------------- + +Lua is embedded in the `cloe-engine`, so if you can run `cloe-engine`, you can +use Lua as an input for `cloe-engine run`, and you can also start an interactive +REPL shell with `cloe-engine shell`: + + $ cloe-engine shell + Cloe 0.22.0 Lua interactive shell + Press [Ctrl+D] or [Ctrl+C] to exit. + > print "hello world!" + hello world! + > + +### System Lua + +You can also install Lua as a system program, such as with Apt: + + sudo apt install lua5.4 + +The `lua5.3` package is not a development dependency of Cloe, but it does +provide a very simple `lua` binary that you can run to get a Lua REPL +independently of `cloe-engine`. Unfortunately, because `cloe-engine` exports +modules containing C++ types and functions, `lua` by itself isn't as useful +for most use-cases pertaining to Cloe. + +:::{note} +In `cloe-engine` we embed Lua 5.4, but on Ubuntu versions older than 22.04 the +latest system version available is `lua5.3`. For the most part, the differences +are not important to us. +::: + +### Lua Rocks + +More useful to us than a system Lua REPL is the [LuaRocks](https://luarocks.org/) +package manager. This allows us to easily install and manage third-party Lua +libraries. These are then available to Cloe itself. + +This can be installed on your system with Apt: + + sudo apt install luarocks + +And then packages, called *rocks*, can be installed with the `luarocks` program: + + luarocks install luaposix + +See the LuaRocks website for a list of available rocks and also for more +information on how to use LuaRocks. + +Suggested Exercises +------------------- + +1. Install the latest version of [Lua](https://www.lua.org) on your system. + +2. Read one of these introductions to Lua: + - [Learn Lua in 15 Minutes](https://learnxinyminutes.com/docs/lua/) + - [Programming in Lua](https://www.lua.org/pil/contents.html) + - [Lua-Users Wiki](http://lua-users.org/wiki/LearningLua) + - [Codecademy Course](https://www.codecademy.com/learn/learn-lua) + +3. Launch the Cloe REPL and run the following snippet: + ```lua + cloe.describe(cloe) + ``` + +4. Install the [LuaRocks](https://luarocks.org/) package manager. + +5. Install the `luaposix` rock. diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 2b6d2fd63..b33c387b2 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -12,6 +12,7 @@ find_package(Boost REQUIRED QUIET) find_package(CLI11 REQUIRED QUIET) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED QUIET) +find_package(sol2 REQUIRED QUIET) include(TargetLinting) include(GNUInstallDirs) @@ -85,8 +86,18 @@ endif() # Library libengine ---------------------------------------------- message(STATUS "Building cloe-enginelib library.") add_library(cloe-enginelib STATIC + src/lua_api.hpp + src/lua_api.cpp + src/lua_setup.hpp + src/lua_setup.cpp + src/lua_setup_duration.cpp + src/lua_setup_fs.cpp + src/lua_setup_stack.cpp + src/lua_setup_sync.cpp src/coordinator.cpp src/coordinator.hpp + src/lua_action.cpp + src/lua_action.hpp src/registrar.hpp # These are added below and depend on CLOE_ENGINE_WITH_SERVER: # src/server.cpp @@ -111,6 +122,9 @@ set_target_properties(cloe-enginelib PROPERTIES ) target_compile_definitions(cloe-enginelib PUBLIC + SOL_ALL_SAFETIES_ON=1 + CLOE_ENGINE_VERSION="${CLOE_ENGINE_VERSION}" + CLOE_ENGINE_TIMESTAMP="${CLOE_ENGINE_TIMESTAMP}" PROJECT_SOURCE_DIR=\"${CMAKE_CURRENT_SOURCE_DIR}\" ) target_include_directories(cloe-enginelib @@ -124,6 +138,7 @@ target_link_libraries(cloe-enginelib cloe::runtime fable::fable boost::boost + sol2::sol2 Threads::Threads ) @@ -140,8 +155,29 @@ else() target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_SERVER=0) endif() +if(BUILD_TESTING) + message(STATUS "Building test-enginelib executable.") + add_executable(test-enginelib + src/lua_stack_test.cpp + ) + set_target_properties(test-enginelib PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + ) + target_link_libraries(test-enginelib + GTest::gtest + GTest::gtest_main + Boost::boost + cloe::models + cloe::stacklib + cloe::enginelib + ) + gtest_add_tests(TARGET test-enginelib) +endif() + # Executable --------------------------------------------------------- message(STATUS "Building cloe-engine executable [with server=${CLOE_ENGINE_WITH_SERVER}].") +add_subdirectory(vendor/linenoise) add_executable(cloe-engine src/main.cpp src/main_commands.hpp @@ -149,6 +185,7 @@ add_executable(cloe-engine src/main_dump.cpp src/main_run.cpp src/main_usage.cpp + src/main_shell.cpp src/main_version.cpp ) set_target_properties(cloe-engine PROPERTIES @@ -170,6 +207,7 @@ target_link_libraries(cloe-engine cloe::stacklib cloe::enginelib CLI11::CLI11 + linenoise::linenoise ) # Installation ------------------------------------------------------- @@ -177,3 +215,8 @@ install(TARGETS cloe-engine RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) + +install( + DIRECTORY lua/ + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cloe/lua +) diff --git a/engine/conanfile.py b/engine/conanfile.py index 9580d82c6..4b1e468b5 100644 --- a/engine/conanfile.py +++ b/engine/conanfile.py @@ -36,6 +36,7 @@ class CloeEngine(ConanFile): no_copy_source = True exports_sources = [ "src/*", + "lua/*", "webui/*", "vendor/*", "CMakeLists.txt", @@ -53,6 +54,7 @@ def requirements(self): self.requires(f"cloe-runtime/{self.version}@cloe/develop") self.requires(f"cloe-models/{self.version}@cloe/develop") self.requires("cli11/2.3.2", private=True) + self.requires("sol2/3.3.1") if self.options.server: self.requires(f"cloe-oak/{self.version}@cloe/develop", private=True) self.requires("boost/1.74.0") @@ -114,8 +116,12 @@ def package_info(self): self.cpp_info.system_libs.append("dl") if self.in_local_cache: bindir = os.path.join(self.package_folder, "bin") + luadir = os.path.join(self.package_folder, "lib/cloe/lua") else: # editable mode bindir = os.path.join(self.build_folder) + luadir = os.path.join(self.source_folder, "lua") self.output.info(f"Appending PATH environment variable: {bindir}") self.runenv_info.prepend_path("PATH", bindir) + self.output.info(f"Appending CLOE_LUA_PATH environment variable: {luadir}") + self.runenv_info.prepend_path("CLOE_LUA_PATH", luadir) diff --git a/engine/lua/cloe-engine/fs.lua b/engine/lua/cloe-engine/fs.lua new file mode 100644 index 000000000..116b64c34 --- /dev/null +++ b/engine/lua/cloe-engine/fs.lua @@ -0,0 +1,174 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +--- +--- @meta cloe-engine.fs +--- +--- This file contains the type annotations of the `cloe-engine.fs` module, +--- which are exported by the cloe-engine executable. +--- + +local fs = {} + +local unavailable = require("cloe-engine").unavailable + +--- Return the basename of the filepath. +--- +--- Examples: +--- +--- assert fs.basename("/bin/bash") == "bash" +--- assert fs.basename("c:\\path") == "c:\\path" -- on linux +--- +--- @param path string filepath +--- @return string # basename of file without parent +--- @nodiscard +function fs.basename(path) + return unavailable("fs.path", path) +end + +--- Return the parent of the filepath. +--- +--- Examples: +--- +--- assert fs.dirname("/bin/bash") == "/bin" +--- assert fs.dirname("/") == "/" +--- assert fs.dirname("") == "" +--- assert fs.dirname("c:\\path") == "" -- on linux +--- +--- @param path string filepath +--- @return string # parent of file without basename +--- @nodiscard +function fs.dirname(path) + return unavailable("fs.dirname", path) +end + +--- Return the normalized filepath. +--- +--- Examples: +--- +--- assert fs.normalize("/bin/../bin/bash") == "/bin/bash" +--- assert fs.normalize("/no/exist//.//../exists") == "/no/exists" +--- +--- @param path string filepath +--- @return string # normalized file +--- @nodiscard +function fs.normalize(path) + return unavailable("fs.normalize", path) +end + +--- Return the true filepath, resolving symlinks and normalizing. +--- If the file does not exist, an empty string is returned. +--- +--- Examples: +--- +--- assert fs.realpath("/bin/../bin/bash") == "/usr/bin/bash" +--- assert fs.realpath("/no/exist") == "" +--- +--- @param path string filepath +--- @return string # real path of file +--- @nodiscard +function fs.realpath(path) + return unavailable("fs.realpath", path) +end + +--- Return the left and right arguments joined together. +--- +--- @param left string filepath +--- @param right string filepath +--- @return string # filepaths joined as "left/right" +--- @nodiscard +function fs.join(left, right) + return unavailable("fs.join", left, right) +end + +--- Return whether path is an absolute path. +--- +--- @param path string filepath to check +--- @return boolean # true if path is absolute +--- @nodiscard +function fs.is_absolute(path) + return unavailable("fs.is_absolute", path) +end + +--- Return whether path is a relative path. +--- +--- @param path string filepath to check +--- @return boolean # true if path is relative +--- @nodiscard +function fs.is_relative(path) + return unavailable("fs.is_relative", path) +end + +--- Return whether path refers to an existing directory. +--- +--- Symlinks are resolved, hence is_dir(path) and is_symlink(path) +--- can both be true. +--- +--- @param file string filepath to check +--- @return boolean # true if path exists and is a directory +--- @nodiscard +function fs.is_dir(file) + return unavailable("fs.is_dir", file) +end + +--- Return whether path refers to an existing normal file. +--- +--- A normal file excludes block devices, pipes, sockets, etc. +--- For these files, use is_other() or exists(). +--- Symlinks are resolved, hence is_file(path) and is_symlink(path) +--- can both be true. +--- +--- @param file string filepath to check +--- @return boolean # true if path exists and is a normal file +--- @nodiscard +function fs.is_file(file) + return unavailable("fs.is_file", file) +end + +--- Return whether path refers to an existing symlink. +--- +--- @param file string filepath to check +--- @return boolean # true if path exists and is a symlink +--- @nodiscard +function fs.is_symlink(file) + return unavailable("fs.is_symlink", file) +end + +--- Return whether path refers to something that exists, +--- but is not a file, directory, or symlink. +--- +--- This can be the case if it is a block device, pipe, socket, etc. +--- +--- @param file string filepath to check +--- @return boolean # true if path exists and is not a normal file, symlink, or directory +--- @nodiscard +function fs.is_other(file) + return unavailable("fs.is_other", file) +end + +--- Return whether path refers to something that exists, +--- regardless what it is. +--- +--- @param file string filepath to check +--- @return boolean # true if path exists +--- @nodiscard +function fs.exists(file) + return unavailable("fs.is_other", file) +end + +return fs diff --git a/engine/lua/cloe-engine/init.lua b/engine/lua/cloe-engine/init.lua new file mode 100644 index 000000000..4ca995df6 --- /dev/null +++ b/engine/lua/cloe-engine/init.lua @@ -0,0 +1,216 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +--- +--- @meta cloe-engine +--- +--- This file contains the type annotations of the `cloe-engine` module, +--- which are exported by the cloe-engine executable. +--- +--- These methods should only be used by the cloe library. +--- + +local engine = { + --- Contains data that will be processed at simulation start, + --- but will not be considered afterward. + initial_input = { + --- @type TriggerConf[] Initial set of triggers to insert into simulation. + triggers = {}, + + --- @type number Number of triggers processed from the initial input. + triggers_processed = 0, + }, + + --- Contains engine state for a simulation. + state = { + --- @type StackConf The current active stack configuration (volatile). + config = {}, + + --- @type table A table of feature flags. + features = { + ["cloe-0.18.0"] = true, + ["cloe-0.18"] = true, + ["cloe-0.19.0"] = true, + ["cloe-0.19"] = true, + ["cloe-0.20.0"] = true, + ["cloe-0.20"] = true, + + ["cloe-stackfile"] = true, + ["cloe-stackfile-4"] = true, + ["cloe-stackfile-4.0"] = true, + ["cloe-stackfile-4.1"] = true, + + ["cloe-server"] = false, + }, + + --- @type table Lua table dumped as JSON report at end of simulation. + report = {}, + + --- @type Coordinator|nil Reference to simulation trigger coordinator type. + scheduler = nil, + + --- @type Stack Reference to simulation stack type. + stack = nil, + + --- @type string|nil Path to currently executing Lua script file. + current_script_file = nil, + + --- @type string|nil Path to directory containing currently executing Lua script file. + current_script_dir = nil, + + --- @type string[] List of Lua scripts that have so far been processed. + scripts_loaded = {}, + }, + + --- @type table Namespaced Lua interfaces of instantiated plugins. + plugins = {}, +} + +require("cloe-engine.types") + +--- Fail with an error message that cloe-engine functionality not available. +--- +--- @param fname string +--- @param ... any Consumed but not used +--- @return any +local function unavailable(fname, ...) + local inspect = require("inspect").inspect + local buf = "cloe-engine." .. fname .. "(" + for i, v in ipairs(...) do + if i ~= 1 then + buf = buf .. ", " + end + buf = buf .. inspect(v) + end + buf = buf .. ")" + error(string.format("error: %s: implementation unavailable outside cloe-engine", buf)) +end + +--- Return two-character string representation of log-level. +--- +--- @param level string +--- @return string +--- @nodiscard +local function log_level_format(level) + if level == "info" then + return "II" + elseif level == "debug" then + return "DD" + elseif level == "warn" then + return "WW" + elseif level == "error" then + return "EE" + elseif level == "critical" then + return "CC" + elseif level == "trace" then + return "TT" + else + return "??" + end +end + +--- Return whether the engine is available. +--- +--- This is not the case when a Lua script is being run with +--- another interpreter, a REPL, or a language server. +--- +--- @return boolean +function engine.is_available() + return false +end + +--- Return path to Lua file that the engine is currently merging, +--- or nil if no file is being loaded. +--- +--- @return string|nil +function engine.get_script_file() + return engine.state.current_script_file +end + +--- Return path to directory containing the Lua file that the engine is +--- currently merging, or nil if no file is being loaded. +--- +--- @return string|nil +function engine.get_script_dir() + return engine.state.current_script_dir +end + +--- Return the global Stack instance. +--- +--- @return Stack +function engine.get_stack() + return unavailable("get_stack") +end + +--- Return the simulation scheduler (aka. Coordinator) global instance. +--- +--- @return Coordinator +function engine.get_scheduler() + return unavailable("get_scheduler") +end + +--- Return the simulation report. +--- +--- @return table +function engine.get_report() + return engine.state.report +end + +--- Return a table of available features. +--- +--- @return table +function engine.get_features() + return engine.state.features +end + +--- Log a message. +--- +--- @param level string +--- @param prefix string +--- @param message string +--- @return nil +function engine.log(level, prefix, message) + print(string.format("%s %s [%s] %s", log_level_format(level), os.date("%T"), prefix, message)) +end + +--- @class CommandSpecA +--- @field path string name or path of executable +--- @field args table list of arguments +--- @field mode? string execution mode (one of "sync", "async", "detach") +--- @field log_output? string output verbosity ("never", "on_error", "always") +--- @field ignore_failure? boolean whether to ignore failure + +--- @class CommandSpecB +--- @field command string command or script to run with default shell +--- @field mode? string execution mode (one of "sync", "async", "detach") +--- @field log_output? string output verbosity ("never", "on_error", "always") +--- @field ignore_failure? boolean whether to ignore failure + +--- @alias CommandSpecC string command or script to run with default shell + +--- @alias CommandSpec (CommandSpecA | CommandSpecB | CommandSpecC) + +--- Run a system command with the cloe executer. +--- +--- @param spec CommandSpec +--- @return string,number +function engine.exec(spec) + return unavailable("exec", spec), 1 +end + +return engine diff --git a/engine/lua/cloe-engine/types.lua b/engine/lua/cloe-engine/types.lua new file mode 100644 index 000000000..04ad3b531 --- /dev/null +++ b/engine/lua/cloe-engine/types.lua @@ -0,0 +1,313 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +--- +--- @meta cloe-engine.types +--- +--- This file contains the type annotations of the `cloe-engine.types` module, +--- which are exported by the cloe-engine executable. +--- +--- These methods should only be used by the cloe library. +--- + +--- @class Stack +local Stack = {} + +--- @class InputConf applied input stack configuration +--- @field file string source of stack, can be "" if unknown or "-" if stdin +--- @field data StackConf the contents of the input stack configuration + +--- @class StackConf stack configuration +--- @field version string version of stack (should be "4") +--- @field include? string[] list of files to include +--- @field engine? EngineConf engine configuration +--- @field simulation? SimulationConf simulation configuration +--- @field server? ServerConf server configuration +--- @field plugins? PluginConf[] list of plugin configurations +--- @field defaults? DefaultsConf default arguments to apply to plugins +--- @field simulators? SimulatorConf[] simulator configuration +--- @field controllers? ControllerConf[] controller configuration +--- @field vehicles? VehicleConf[] vehicle configuration +--- @field triggers? TriggerConf[] triggers to schedule + +--- @class EngineConf +--- @field hooks? { pre_connect?: CommandSpec[], post_disconnect?: CommandSpec[] } +--- @field ignore? string[] fields to ignore in input +--- @field keep_alive? boolean whether to keep cloe-engine alive after simulation end +--- @field output? EngineOutputConf output configuration +--- @field plugin_path? string[] list of plugin files to load +--- @field plugins? { allow_clobber?: boolean, ignore_failure?: boolean, ignore_missing?: boolean } +--- @field polling_interval? number how many milliseconds to wait in pause state (default: 100) +--- @field registry_path? string path to use for registry (where output is also written) +--- @field security? EngineSecurityConf security configuration +--- @field triggers? EngineTriggerConf trigger configuration +--- @field watchdog? EngineWatchdogConf watchdog configuration + +--- @class EnginePluginConf +--- @field allow_clobber? boolean whether to allow a plugin to override a previously loaded plugin +--- @field ignore_failure? boolean whether to ignore plugin loading failure (e.g. not a cloe plugin) +--- @field ignore_missing? boolean whether to ignore plugins that are specified but missing + +--- @class EngineTriggerConf +--- @field ignore_source? boolean whether to ignore the "source" field into account + +--- @class EngineOutputConf +--- @field path? string directory prefix for all output files (relative to registry_path) +--- @field clobber? boolean whether to overwrite pre-existing files (default: true) +--- @field files? EngineOutputFilesConf configuration for each output file + +--- @class EngineOutputFilesConf +--- @field config? string simulation configuration (result of defaults and loaded configuration) +--- @field result? string simulation result and report +--- @field triggers? string list of applied triggers +--- @field signals? string list of signals +--- @field signals_autocompletion? string signal autocompletion file for Lua +--- @field api_recording? string data stream recording file + +--- @class EngineSecurityConf +--- @field enable_command_action? boolean whether to allow commands (default: true) +--- @field enable_hooks_section? boolean whether to allow hooks to run (default: true) +--- @field enable_include_section? boolean whether to allow files to include other files (default: true) +--- @field max_include_depth? number how many includes deep we can do before aborting (default: 64) + +--- @class EngineWatchdogConf +--- @field default_timeout? number in [milliseconds] +--- @field mode? string one of "off", "log", "abort", "kill" (default: "off") +--- @field state_timeouts? table timeout values for specific engine states +--- +--- @class LoggingConf +--- @field name string name of logger, e.g. "cloe" +--- @field pattern? string pattern to use for logging output +--- @field level? string one of "debug", "trace", "info", "warn", "error", "critical" + +--- @class ServerConf +--- @field listen? boolean whether to enable the server (default: true) +--- @field listen_address? string address to listen on (default: "127.0.0.1") +--- @field listen_port? number port to listen on (default: 8080) +--- @field listen_threads? number threads to use (deprecated) +--- @field api_prefix? string endpoint prefix for API endpoints (default: "/api") +--- @field static_prefix? string endpoint prefix for static assets (default: "") + +--- @class PluginConf +--- @field path string path to plugin or directory to load +--- @field name? string name to load plugin as (used with binding field later) +--- @field prefix? string apply prefix to plugin name (useful for directories) +--- @field ignore_missing? boolean ignore plugin if missing +--- @field ignore_failure? boolean ignore plugin if cannot load +--- @field allow_clobber? boolean allow plugin to overwrite previously loaded of same name + +--- @class SimulatorConf +--- @field binding string plugin name +--- @field name? string simulator name, defaults to plugin name +--- @field args? table simulator configuration (plugin specific) + +--- @class TriggerConf +--- @field action string|table|fun(sync: Sync):(boolean?) +--- @field event string|table +--- @field label? string +--- @field source? string +--- @field sticky? boolean +--- @field conceal? boolean +--- @field optional? boolean +--- @field group? string + +--- @class VehicleConf +--- @field name string vehicle name, used in controller configuration +--- @field from VehicleFromSimConf|string vehicle data source +--- @field components? table component configuration + +--- @class VehicleFromSimConf +--- @field simulator string simulator name +--- @field index? number vehicle index +--- @field name? string vehicle name + +--- @class ControllerConf +--- @field binding string plugin name +--- @field name? string controller name, defaults to plugin name +--- @field vehicle string vehicle to attach to (name in vehicle conf) +--- @field args? table controller configuration (plugin specific) + +--- @class SimulationConf +--- @field abort_on_controller_failure? boolean whether to abort when controller fails (default: true) +--- @field controller_retry_limit? number how many times to let controller attempt to make progress (default: 1024) +--- @field controller_retry_sleep? number how long to wait between controller attempts, in [milliseconds] +--- @field model_step_width? number how long a single cycle lasts in the simulation, in [nanoseconds] + +--- @class ComponentConf +--- @field binding string plugin name +--- @field from string[]|string source components to use as input +--- @field name? string name to use for component, defaults to plugin name +--- @field args? table component configuration (plugin specific) + +--- @class DefaultConf +--- @field name? string name to match +--- @field binding? string binding to match +--- @field args table default arguments to apply (can be overridden) + +--- @class DefaultsConf +--- @field components DefaultConf[] defaults for components +--- @field simulators DefaultConf[] defaults for simulators +--- @field controllers DefaultConf[] defaults for controllers + +--- Merge JSON stackfile into simulation configuration. +--- +--- @param filepath string +--- @return nil +function Stack:merge_stackfile(filepath) end + +--- Merge JSON string into simulation configuration. +--- +--- @param json string Input JSON (use Lua multiline feature) +--- @param source_filepath string Filepath to use for error messages +--- @return nil +function Stack:merge_stackjson(json, source_filepath) end + +--- Merge Lua table into simulation configuration. +--- +--- This converts the table to JSON, then loads it. +--- +--- @param tbl StackConf Input JSON as Lua table +--- @param source_filepath string Filepath to use for error messages +--- @return nil +function Stack:merge_stacktable(tbl, source_filepath) end + +--- Return the current active configuration of the stack file. +--- +--- This is not the same thing as the input configuration! +--- +--- @return StackConf +function Stack:active_config() end + +--- Return an array of input configuration of the stack file. +--- +--- This is not the same thing as the active configuration! +--- +--- @return InputConf[] +function Stack:input_config() end + +--- @class Duration +local Duration = {} + +--- Return new Duration instance from duration format. +--- +--- @param format string Duration such as "1s" or "1.5 ms" +--- @return Duration +function Duration.new(format) end + +--- Return Duration as nanoseconds. +--- +--- @return number nanoseconds +function Duration:ns() end + +--- Return Duration as microseconds. +--- +--- @return number microseconds +function Duration:us() end + +--- Return Duration as milliseconds. +--- +--- @return number milliseconds +function Duration:ms() end + +--- Return Duration as seconds. +--- +--- @return number seconds +function Duration:s() end + +--- @class Sync +local Sync = {} + +--- Return current simulation step. +--- +--- @return integer +--- @nodiscard +function Sync:step() end + +--- Return simulation step_width. +--- +--- @return Duration +--- @nodiscard +function Sync:step_width() end + +--- Return current simulation time. +--- +--- @return Duration +--- @nodiscard +function Sync:time() end + +--- Return estimated simulation end. +--- +--- If unknown, then 0 is returned. +--- +--- @return Duration +--- @nodiscard +function Sync:eta() end + +--- Return current simulation realtime-factor target. +--- +--- @return number +--- @nodiscard +function Sync:realtime_factor() end + +--- Return whether realtime-factor target is unlimited. +--- +--- If true, then the simulation runs as fast as possible and never pads +--- cycles with waiting time. +--- +--- @return boolean +--- @nodiscard +function Sync:is_realtime_factor_unlimited() end + +--- Return estimated achievable simulation realtime-factor target. +--- +--- @return number +--- @nodiscard +function Sync:achievable_realtime_factor() end + +--- @class Coordinator +local Coordinator = {} + +--- Insert a trigger into the coordinator event queue. +--- +--- @param trigger TriggerConf trigger schema to insert +--- @return nil +function Coordinator:insert_trigger(trigger) end + +--- Execute an action known to the coordinator immediately. +--- +--- @param action string|table action schema to insert +--- @return nil +function Coordinator:execute_action(action) end + +--- @enum LogLevel +local LogLevel = { + TRACE = "trace", + DEBUG = "debug", + INFO = "info", + WARN = "warn", + ERROR = "error", + CRITICAL = "critical", +} + +return { + Stack = Stack, + Duration = Duration, + Sync = Sync, + Coordinator = Coordinator, + LogLevel = LogLevel, +} diff --git a/engine/lua/cloe/actions.lua b/engine/lua/cloe/actions.lua new file mode 100644 index 000000000..64d9b3aa6 --- /dev/null +++ b/engine/lua/cloe/actions.lua @@ -0,0 +1,106 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local types = require("cloe-engine.types") +local validate = require("cloe.typecheck").validate + +local actions = {} + +--- Stop the simulation. +function actions.stop() + return "stop" +end + +--- Stop the simulation and mark the outcome as failed. +function actions.fail() + return "fail" +end + +--- Stop the simulation and mark the outcome as success. +function actions.succeed() + return "succeed" +end + +--- Insert a trigger at this time. +--- +--- @param triggers TriggerConf[] +function actions.insert(triggers) + return { + name = "insert", + items = triggers, + } +end + +--- Keep simulation alive after termination. +--- +--- This can be useful if you still want to access the web server. +function actions.keep_alive() + return "keep_alive" +end + +--- Run a command on the system. +--- +--- @deprecated Use a Lua function with `cloe.system.exec()`. +--- @param cmd string +--- @param options? { ignore_failure?: boolean, log_output?: string, mode?: string } +function actions.command(cmd, options) + validate("cloe.actions.command(string)", cmd) + local trigger = options or {} + trigger.name = "command" + trigger.command = cmd + return trigger +end + +--- Log a message with the cloe logging framework. +--- +--- @deprecated Use a Lua function with `cloe.log()`. +--- @param level? LogLevel +--- @param msg string +function actions.log(level, msg) + validate("cloe.actions.log(string?, string)", level, msg) + return { + name = "log", + level = level, + msg = msg, + } +end + +--- Lua string to execute. +--- +--- This is not the recommended way to run Lua as an action. +--- +--- @deprecated Use a Lua function directly. +--- @param s string +function actions.lua(s) + return { + name = "lua", + script = s, + } +end + +--- Realtime factor to apply to simulation speed. +--- +--- @param factor number where -1 is infinite speed, 0 is invalid, and 1.0 is realtime +function actions.realtime_factor(factor) + return { + name = "realtime_factor", + realtime_factor = factor, + } +end + +return actions diff --git a/engine/lua/cloe/engine.lua b/engine/lua/cloe/engine.lua new file mode 100644 index 000000000..8a3e86df6 --- /dev/null +++ b/engine/lua/cloe/engine.lua @@ -0,0 +1,378 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local api = require("cloe-engine") +local fs = require("cloe-engine.fs") +local luax = require("cloe.luax") + +local typecheck = require("cloe.typecheck") +local validate, validate_shape = typecheck.validate, typecheck.validate_shape + +--- Let the language-server know we are importing cloe-engine.types into engine: +---@module 'cloe-engine.types' +local engine = {} + +-- Import all types from cloe-engine into this namespace. +for k, v in pairs(require("cloe-engine.types")) do + engine[k] = v +end + +--- Return if Cloe has feature as defined by string. +--- +--- @param id string feature identifier, such as `cloe-0.20` +--- @return boolean +--- @nodiscard +function engine.has_feature(id) + validate("cloe.has_feature(string)", id) + return api.state.features[id] and true or false +end + +--- Throw an exception if Cloe does not have feature as defined by string. +--- +--- @param id string feature identifier, such as `cloe-0.20` +--- @return nil +function engine.require_feature(id) + validate("cloe.require_feature(string)", id) + if not engine.has_feature(id) then + error("required feature not available: " .. id) + end +end + +--- Return the active stack configuration as a table. +--- +--- Modifying the values here have no effect. It is simply a dump +--- of the JSON representation of a stack configuration. +--- +--- @return StackConf +function engine.config() + return api.state.config +end + +--- Try to load (merge) stackfile. +--- +--- @param file string file path, possibly relative to calling file +--- @return Stack +function engine.load_stackfile(file) + validate("cloe.load_stackfile(string)", file) + local cwd = api.state.current_script_dir or "." + if fs.is_relative(file) then + file = cwd .. "/" .. file + end + api.state.stack:merge_stackfile(file) + return api.state.stack +end + +--- Read JSON file into Lua types (most likely as Lua table). +--- +--- @param file string file path +--- @return any # JSON converted into Lua types +--- @nodiscard +function engine.open_json(file) + validate("cloe.open_json(string)", file) + local fp = io.open(file, "r") + if not fp then + error("cannot open file: " .. file) + end + local data = fp:read("*all") + local json = require("json") + return json:decode(data) +end + +--- Try to apply the supplied table to the stack. +--- +--- @param stack StackConf|string stack format as Lua table (or JSON string) +--- @return nil +function engine.apply_stack(stack) + validate("cloe.apply_stack(string|table)", stack) + local file = api.state.current_script_file or "" + if type(stack) == "table" then + api.state.stack:merge_stacktable(stack --[[ @as table ]], file) + else + api.state.stack:merge_stackjson(stack --[[ @as string ]], file) + end +end + +--- Log a message with a given severity. +--- +--- For example: +--- +--- cloe.log("info", "Got value of %d, expected %d", 4, 6) +--- cloe.log(cloe.LogLevel.WARN, "Got value of %s, expected %s", 4, 6) +--- +--- @param level LogLevel|string severity level, one of: trace, debug, info, warn, error, critical +--- @param fmt string format string with trailing arguments compatible with string.format +--- @param ... any arguments to format string +--- @return nil +function engine.log(level, fmt, ...) + validate("cloe.log(string, string, [?any]...)", level, fmt, ...) + local msg = string.format(fmt, ...) + api.log(level, "lua", msg) +end + +--- Schedule a trigger. +--- +--- It is not recommended to use this low-level function, as it is viable to change. +--- Instead, use one of the following functions: +--- - `cloe.schedule()` +--- - `cloe.schedule_these()` +--- - `cloe.schedule_test()` +--- +--- @param trigger TriggerConf +--- @return nil +function engine.insert_trigger(trigger) + -- A Lua script runs before a scheduler is started, so the initial + -- events are put in a queue and picked up by the engine at simulation + -- start. After this, cloe.state.scheduler exists and we can use its + -- methods. + if api.state.scheduler then + api.state.scheduler:insert_trigger(trigger) + else + table.insert(api.initial_input.triggers, trigger) + end +end + +--- Execute a trigger action directly. +--- +--- This is useful when you need to do something but can't wait for +--- a new simulation cycle. Note that not all actions are instantaneous. +--- +--- @param action string|table +--- @return nil +function engine.execute_action(action) + validate("cloe.execute_action(string|table)", action) + if api.state.scheduler then + api.state.scheduler:execute_action(action) + else + error("can only execute actions within scheduled events") + end +end + +--- @alias EventFunction fun(sync: Sync):boolean + +--- @alias ActionFunction fun(sync: Sync):boolean? + +--- @class Task +--- @field on string|table|EventFunction what event to trigger on (required) +--- @field run string|table|ActionFunction what to do when the event triggers (required) +--- @field desc? string description of what the trigger is about (default: empty) +--- @field enable? boolean|fun():boolean whether to schedule the trigger or not (default: true) +--- @field group? string whether to assign a group to this trigger (default: nil) +--- @field pin? boolean whether the trigger remains after being run (default: false) +--- @field priority? integer priority to use when multiple events occur simultaneously (currently unimplemented) +--- @field source? string where to the trigger is defined (defined automatically) +local Task +do + local types = require("tableshape").types + Task = types.shape { + on = types.string + types.table + types.func, + run = types.string + types.table + types.func, + desc = types.string:is_optional(), + enable = types.boolean:is_optional(), + group = types.string:is_optional(), + pin = types.boolean:is_optional(), + priority = types.integer:is_optional(), + source = types.string:is_optional(), + } +end + +--- @class PartialTask +--- @field on? string|table|EventFunction what event to trigger on (required) +--- @field run? string|table|ActionFunction what to do when the event triggers (required) +--- @field desc? string description of what the trigger is about (default: empty) +--- @field enable? boolean|fun():boolean whether to schedule the trigger or not (default: true) +--- @field group? string whether to assign a group to this trigger (default: nil) +--- @field pin? boolean whether the trigger remains after being run (default: false) +--- @field priority? integer priority to use when multiple events occur simultaneously (currently unimplemented) +--- @field source? string where to the trigger is defined (defined automatically) +local PartialTask +local PartialTaskSpec +do + local types = require("tableshape").types + PartialTaskSpec = { + on = (types.string + types.table + types.func):is_optional(), + run = (types.string + types.table + types.func):is_optional(), + desc = types.string:is_optional(), + enable = types.boolean:is_optional(), + group = types.string:is_optional(), + pin = types.boolean:is_optional(), + priority = types.integer:is_optional(), + source = types.string:is_optional(), + } + PartialTask = types.shape(PartialTaskSpec) +end + +--- @class Tasks: PartialTask +--- @field [number] PartialTask an array of tasks, falling back to defaults specified above +local Tasks +do + local types = require("tableshape").types + Tasks = types.shape( + PartialTaskSpec, + { + extra_fields = types.array_of(PartialTask) + } + ) +end + +--- Expand a list of partial tasks to a list of complete tasks. +--- +--- @param tasks Tasks +--- @return Task[] +--- @nodiscard +function engine.expand_tasks(tasks) + local results = {} + for _, partial in ipairs(tasks) do + local task = { + on = partial.on or tasks.on, + run = partial.run or tasks.run, + enable = partial.enable == nil and tasks.enable or partial.enable, + group = partial.group or tasks.group, + priority = partial.priority or tasks.priority, + pin = partial.pin == nil and tasks.pin or partial.pin, + desc = partial.desc or tasks.desc, + } + table.insert(results, task) + end + return results +end + +--- Return whether the task is enabled. +--- +--- @param spec Task|Test +--- @return boolean +--- @nodiscard +local function is_task_enabled(spec) + local default = true + if spec.enable == nil then + return default + elseif type(spec.enable) == "boolean" then + return spec.enable --[[@as boolean]] + elseif type(spec.enable) == "function" then + return spec.enable() + else + error("enable: invalid type, expect boolean|fun(): boolean") + end +end + +--- Schedule a task (i.e., event-action pair). +--- +--- @param task Task +--- @return boolean # true if schedule +function engine.schedule(task) + validate_shape("cloe.schedule(Task)", Task, task) + if not is_task_enabled(task) then + return false + end + + local event = task.on + local action = task.run + local action_source = task.source + if not action_source and type(action) == "function" then + local debinfo = debug.getinfo(action) + action_source = string.format("%s:%s-%s", debinfo.short_src, debinfo.linedefined, debinfo.lastlinedefined) + end + + -- TODO: Replace this with proper Lua function events + local pin = task.pin or false + if type(event) == "function" then + local old_event = event + local old_action = action + local old_pin = pin + pin = true + event = "loop" + action = function(sync) + if old_event(sync) then + if type(old_action) == "function" then + old_action(sync) + else + -- TODO: Maybe this works for functions too + engine.execute_action(old_action) + end + return old_pin + end + end + end + + local group = task.group or "" + local priority = task.priority or 100 + + engine.insert_trigger({ + label = task.desc, + event = event, + action = action, + action_source = action_source, + sticky = pin, + priority = priority, + group = group, + }) + return true +end + +--- Schedule one or more event-action pairs, +--- with defaults specified as keys inline. +--- +--- @param tasks Tasks tasks to schedule +--- @return boolean[] # list mapping whether each task was scheduled +function engine.schedule_these(tasks) + validate_shape("cloe.schedule_these(Tasks)", Tasks, tasks) + local results = {} + for _, task in ipairs(engine.expand_tasks(tasks)) do + local result = engine.schedule(task) + table.insert(results, result) + end + return results +end + +--- @class Test +--- @field id string unique identifier to use for test (required) +--- @field on string|EventFunction when to start the test execution (required) +--- @field run fun(z: TestFixture, sync: Sync) test definition (required) +--- @field desc? string description of what the test is about (default: empty) +--- @field info? table metadata to include in the test report (default: nil) +--- @field enable? boolean|fun():boolean whether the test should be scheduled (default: true) +--- @field terminate? boolean|fun():boolean whether to automatically terminate simulation if this is last test run (default: true) +local Test +do + local types = require("tableshape").types + Test = types.shape { + id = types.string, + on = types.string + types.table + types.func, + run = types.string + types.table + types.func, + desc = types.string:is_optional(), + info = types.table:is_optional(), + enable = types.boolean:is_optional(), + terminate = types.boolean:is_optional(), + } +end + +--- Schedule a test as a coroutine that can yield to Cloe. +--- +--- @param test Test test specification (requires fields: id, on, run) +function engine.schedule_test(test) + validate_shape("cloe.schedule_test(Test)", Test, test) + if not is_task_enabled(test) then + return false + end + + --- We don't want users to see private method `schedule_self()`, + --- but we need to use it here to actually schedule the test. + --- @diagnostic disable-next-line: invisible + require("cloe.testing").TestFixture.new(test):schedule_self() +end + +return engine diff --git a/engine/lua/cloe/events.lua b/engine/lua/cloe/events.lua new file mode 100644 index 000000000..baf94a8fa --- /dev/null +++ b/engine/lua/cloe/events.lua @@ -0,0 +1,195 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local api = require("cloe-engine") +local types = require("cloe-engine.types") +local validate = require("cloe.typecheck").validate + +local events = {} + +--- Event to be used for the on key in schedule_test. +--- +--- Example: +--- +--- cloe.schedule_test { +--- id = "TEST-A", +--- on = cloe.events.start(), +--- -- ... +--- } +--- +--- cloe.schedule_test { +--- id = "TEST-B", +--- on = cloe.events.after_tests("TEST-A"), +--- -- ... +--- } +--- +--- cloe.schedule_test { +--- id = "FINAL", +--- on = cloe.events.after_tests("TEST-A", "TEST-B"), +--- -- ... +--- } +--- +--- @param ... string tests to wait for +--- @return fun():boolean +function events.after_tests(...) + validate("cloe.events.after_tests(string...)", ...) + local names = { ... } + + if #names == 1 then + local name = names[1] + return function() + return api.state.report.tests[name].complete + end + else + return function() + for _, k in ipairs(names) do + if not api.state.report.tests[k].complete then + return false + end + end + return true + end + end +end + +--- Schedule every duration, starting with 0. +--- +--- Note: You have to pin the schedule otherwise it will be descheduled +--- after running once. +--- +--- @param duration string|Duration +function events.every(duration) + validate("cloe.events.every(string|userdata)", duration) + if type(duration) == "string" then + duration = types.Duration.new(duration) + end + if duration:ns() % api.state.config.simulation.model_step_width ~= 0 then + error("interval duration is not a multiple of nominal step width") + end + return function(sync) + return sync:time():ms() % duration:ms() == 0 + end +end + +--- When the simulation is starting. +function events.start() + return "start" +end + +--- When the simulation has stopped. +function events.stop() + return "stop" +end + +--- When the simulation is marked as a fail (after stopping). +function events.failure() + return "failure" +end + +--- When the simulation is marked as a pass (after stopping). +function events.success() + return "success" +end + +--- Every loop. +--- +--- Note: You have to pin the schedule otherwise it will be descheduled +--- after running once. +function events.loop() + return "loop" +end + +--- Schedule for absolute simulation time specified. +--- +--- Warning: If the specified time is in the past, then the behavior is *undefined*. +--- +--- @param simulation_time string|Duration +function events.time(simulation_time) + validate("cloe.events.every(string|userdata)", simulation_time) + if type(simulation_time) == "string" then + simulation_time = types.Duration.new(simulation_time) + end + return string.format("time=%s", simulation_time:s()) +end + +--- Schedule for next cycle after specified duration. +--- +--- @param simulation_duration? string|Duration +function events.next(simulation_duration) + validate("cloe.events.next([string|userdata])", simulation_duration) + if not simulation_duration then + return "next" + end + + if type(simulation_duration) == "string" then + simulation_duration = types.Duration.new(simulation_duration) + end + return string.format("next=%s", simulation_duration:s()) +end + +--- When the simulation is paused. +--- +--- This will trigger every few milliseconds while in the pause state. +function events.pause() + return "pause" +end + +--- When the simulation resumes after pausing. +function events.resume() + return "resume" +end + +--- When the simulation is reset. +--- +--- @deprecated Currently this behavior is unsupported. +function events.reset() + return "reset" +end + +--- When condition() is true or after timeout duration has elapsed. +--- +--- NOTE: Currently it is not possible to easily determine if the +--- event is triggering because of a timeout or because the condition +--- evaluated to true. +--- +--- Example: +--- +--- cloe.schedule { +--- on = cloe.events.with_timeout(nil, function(sync) return cloe.signal("SIG_A") == 5 end, "10s"), +--- action = cloe.actions.succeed(), +--- } +--- +--- @param current_sync? Sync current Sync, possibly nil +--- @param condition fun(sync: Sync):boolean +--- @param timeout string|Duration time to wait until giving up +function events.with_timeout(current_sync, condition, timeout) + if type(timeout) == "string" then + timeout = types.Duration.new(timeout) + end + if current_sync then + timeout = current_sync:time() + timeout + end + return function(sync) + if sync:time() > timeout then + return true + end + return condition(sync) + end +end + +return events diff --git a/engine/lua/cloe/init.lua b/engine/lua/cloe/init.lua new file mode 100644 index 000000000..96ca5968d --- /dev/null +++ b/engine/lua/cloe/init.lua @@ -0,0 +1,121 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local api = require("cloe-engine") +local engine = require("cloe.engine") + +---@module 'cloe.engine' +local cloe = {} + +-- Re-export everything from cloe.engine into cloe. +for k, v in pairs(engine) do + if cloe[k] then + error("duplicate function definition in cloe.engine") + end + cloe[k] = v +end + +--- Table of functions for dealing with file paths. +cloe.fs = require("cloe-engine.fs") + +--- Table of common events for use with tasks and tests. +cloe.events = require("cloe.events") + +--- Table of common actions for use with tasks. +cloe.actions = require("cloe.actions") + +--- Validate input arguments of a function in a single line. +--- +--- This is basically a specialized version of the typecheck.argscheck +--- function, in that it does not wrap the original function, +--- thereby preserving the type data that the Lua language server +--- uses to provide hints and autocompletion. +--- +--- @see cloe.typecheck.validate +cloe.validate = require("cloe.typecheck").validate + +--- Validate the shape (from tableshape) of a table or type. +--- +--- @see cloe.typecheck.validate_shape +cloe.validate_shape = require("cloe.typecheck").validate_shape + +--- Return a human-readable representation of a Lua object. +--- +--- This is primarily used for debugging and should not be used +--- when performance is important. It is a table, but acts as a +--- function. +--- +--- For more details, see: https://github.com/kikito/inspect.lua +--- +--- @see inspect.inspect +cloe.inspect = require("inspect").inspect + +--- Print a human-readable representation of a Lua object. +--- +--- This just prints the output of inspect. +--- +--- For more details, see: https://github.com/kikito/inspect.lua +--- +--- @param root any +--- @param options? table +--- @return nil +function cloe.describe(root, options) + print(require("inspect").inspect(root, options)) +end + +--- Require a module, prioritizing modules relative to the script +--- launched by cloe-engine. +--- +--- If api.state.current_script_dir is nil, this is equivalent to require(). +--- +--- @param module string module identifier, such as "project" +function cloe.require(module) + cloe.validate("cloe.require(string)", module) + local script_dir = api.state.current_script_dir + if script_dir then + local old_package_path = package.path + package.path = string.format("%s/?.lua;%s/?/init.lua;%s", script_dir, script_dir, package.path) + local module_table = require(module) + package.path = old_package_path + return module_table + else + engine.log("warn", "cloe.require() expects cloe-engine.get_script_dir() ~= nil, but it is not", nil) + return require(module) + end +end + +--- Initialize report metadata. +--- +--- @param header table Optional report header information that will be merged in. +--- @return table +function cloe.init_report(header) + cloe.validate("cloe.init_report(?table)", header) + local system = require("cloe.system") + local report = api.state.report + report.metadata = { + hostname = system.get_hostname(), + username = system.get_username(), + datetime = system.get_datetime(), + } + if header then + report.metadata = require("cloe.luax").tbl_deep_extend("force", report.metadata, header) + end + return report +end + +return cloe diff --git a/engine/lua/cloe/luax.lua b/engine/lua/cloe/luax.lua new file mode 100644 index 000000000..7d4406541 --- /dev/null +++ b/engine/lua/cloe/luax.lua @@ -0,0 +1,763 @@ +-- NOTICE +-- +-- This is a *derivative* work of source code from the Apache-2.0 licensed +-- Neovim project. +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Copyright Neovim contributors. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- SPDX-License-Identifier: Apache-2.0 + +local luax = {} + +--- Returns a deep copy of the given object. Non-table objects are copied as +--- in a typical Lua assignment, whereas table objects are copied recursively. +--- Functions are naively copied, so functions in the copied table point to the +--- same functions as those in the input table. Userdata and threads are not +--- copied and will throw an error. +--- +---@param orig table Table to copy +---@return table Table of copied keys and (nested) values. +function luax.deepcopy(orig) end -- luacheck: no unused +luax.deepcopy = (function() + local function _id(v) + return v + end + + local deepcopy_funcs = { + table = function(orig, cache) + if cache[orig] then + return cache[orig] + end + local copy = {} + + cache[orig] = copy + local mt = getmetatable(orig) + for k, v in pairs(orig) do + copy[luax.deepcopy(k, cache)] = luax.deepcopy(v, cache) + end + return setmetatable(copy, mt) + end, + number = _id, + string = _id, + ['nil'] = _id, + boolean = _id, + ['function'] = _id, + } + + return function(orig, cache) + local f = deepcopy_funcs[type(orig)] + if f then + return f(orig, cache or {}) + else + error('Cannot deepcopy object of type ' .. type(orig)) + end + end +end)() + +--- Splits a string at each instance of a separator. +--- +---@see |luax.split()| +---@see https://www.lua.org/pil/20.2.html +---@see http://lua-users.org/wiki/StringLibraryTutorial +--- +---@param s string String to split +---@param sep string Separator or pattern +---@param plain boolean If `true` use `sep` literally (passed to string.find) +---@return function Iterator over the split components +function luax.gsplit(s, sep, plain) + luax.validate({ s = { s, 's' }, sep = { sep, 's' }, plain = { plain, 'b', true } }) + + local start = 1 + local done = false + + local function _pass(i, j, ...) + if i then + assert(j + 1 > start, 'Infinite loop detected') + local seg = s:sub(start, i - 1) + start = j + 1 + return seg, ... + else + done = true + return s:sub(start) + end + end + + return function() + if done or (s == '' and sep == '') then + return + end + if sep == '' then + if start == #s then + done = true + end + return _pass(start + 1, start) + end + return _pass(s:find(sep, start, plain)) + end +end + +--- Splits a string at each instance of a separator. +--- +--- Examples: +---
+---  split(":aa::b:", ":")     --> {'','aa','','b',''}
+---  split("axaby", "ab?")     --> {'','x','y'}
+---  split("x*yz*o", "*", {plain=true})  --> {'x','yz','o'}
+---  split("|x|y|z|", "|", {trimempty=true}) --> {'x', 'y', 'z'}
+--- 
+--- +---@see |luax.gsplit()| +--- +---@param s string String to split +---@param sep string Separator or pattern +---@param kwargs table Keyword arguments: +--- - plain: (boolean) If `true` use `sep` literally (passed to string.find) +--- - trimempty: (boolean) If `true` remove empty items from the front +--- and back of the list +---@return table List of split components +function luax.split(s, sep, kwargs) + luax.validate({ kwargs = { kwargs, 't', true } }) + kwargs = kwargs or {} + local plain = kwargs.plain + local trimempty = kwargs.trimempty + + local t = {} + local skip = trimempty + for c in luax.gsplit(s, sep, plain) do + if c ~= '' then + skip = false + end + + if not skip then + table.insert(t, c) + end + end + + if trimempty then + for i = #t, 1, -1 do + if t[i] ~= '' then + break + end + table.remove(t, i) + end + end + + return t +end + +--- Return a list of all keys used in a table. +--- However, the order of the return table of keys is not guaranteed. +--- +---@see From https://github.com/premake/premake-core/blob/master/src/base/table.lua +--- +---@param t table Table +---@return table List of keys +function luax.tbl_keys(t) + assert(type(t) == 'table', string.format('Expected table, got %s', type(t))) + + local keys = {} + for k, _ in pairs(t) do + table.insert(keys, k) + end + return keys +end + +--- Return a list of all values used in a table. +--- However, the order of the return table of values is not guaranteed. +--- +---@param t table Table +---@return table List of values +function luax.tbl_values(t) + assert(type(t) == 'table', string.format('Expected table, got %s', type(t))) + + local values = {} + for _, v in pairs(t) do + table.insert(values, v) + end + return values +end + +--- Apply a function to all values of a table. +--- +---@param func function|table Function or callable table +---@param t table Table +---@return table Table of transformed values +function luax.tbl_map(func, t) + luax.validate({ func = { func, 'c' }, t = { t, 't' } }) + + local rettab = {} + for k, v in pairs(t) do + rettab[k] = func(v) + end + return rettab +end + +--- Filter a table using a predicate function +--- +---@param func function|table Function or callable table +---@param t table Table +---@return table Table of filtered values +function luax.tbl_filter(func, t) + luax.validate({ func = { func, 'c' }, t = { t, 't' } }) + + local rettab = {} + for _, entry in pairs(t) do + if func(entry) then + table.insert(rettab, entry) + end + end + return rettab +end + +--- Checks if a list-like (vector) table contains `value`. +--- +---@param t table Table to check +---@param value any Value to compare +---@return boolean `true` if `t` contains `value` +function luax.tbl_contains(t, value) + luax.validate({ t = { t, 't' } }) + + for _, v in ipairs(t) do + if v == value then + return true + end + end + return false +end + +--- Checks if a table is empty. +--- +---@see https://github.com/premake/premake-core/blob/master/src/base/table.lua +--- +---@param t table Table to check +---@return boolean `true` if `t` is empty +function luax.tbl_isempty(t) + assert(type(t) == 'table', string.format('Expected table, got %s', type(t))) + return next(t) == nil +end + +--- We only merge empty tables or tables that are not a list +---@private +local function can_merge(v) + return type(v) == 'table' and (luax.tbl_isempty(v) or not luax.tbl_islist(v)) +end + +local function tbl_extend(behavior, deep_extend, ...) + if behavior ~= 'error' and behavior ~= 'keep' and behavior ~= 'force' then + error('invalid "behavior": ' .. tostring(behavior)) + end + + if select('#', ...) < 2 then + error( + 'wrong number of arguments (given ' + .. tostring(1 + select('#', ...)) + .. ', expected at least 3)' + ) + end + + local ret = {} + if luax._empty_dict_mt ~= nil and getmetatable(select(1, ...)) == luax._empty_dict_mt then + ret = luax.empty_dict() + end + + for i = 1, select('#', ...) do + local tbl = select(i, ...) + luax.validate({ ['after the second argument'] = { tbl, 't' } }) + if tbl then + for k, v in pairs(tbl) do + if deep_extend and can_merge(v) and can_merge(ret[k]) then + ret[k] = tbl_extend(behavior, true, ret[k], v) + elseif behavior ~= 'force' and ret[k] ~= nil then + if behavior == 'error' then + error('key found in more than one map: ' .. k) + end -- Else behavior is "keep". + else + ret[k] = v + end + end + end + end + return ret +end + +--- Merges two or more map-like tables. +--- +---@see |extend()| +--- +---@param behavior string Decides what to do if a key is found in more than one map: +--- - "error": raise an error +--- - "keep": use value from the leftmost map +--- - "force": use value from the rightmost map +---@param ... table Two or more map-like tables +---@return table Merged table +function luax.tbl_extend(behavior, ...) + return tbl_extend(behavior, false, ...) +end + +--- Merges recursively two or more map-like tables. +--- +---@see |luax.tbl_extend()| +--- +---@param behavior string Decides what to do if a key is found in more than one map: +--- - "error": raise an error +--- - "keep": use value from the leftmost map +--- - "force": use value from the rightmost map +---@param ... table Two or more map-like tables +---@return table Merged table +function luax.tbl_deep_extend(behavior, ...) + return tbl_extend(behavior, true, ...) +end + +--- Deep compare values for equality +--- +--- Tables are compared recursively unless they both provide the `eq` metamethod. +--- All other types are compared using the equality `==` operator. +---@param a any First value +---@param b any Second value +---@return boolean `true` if values are equals, else `false` +function luax.deep_equal(a, b) + if a == b then + return true + end + if type(a) ~= type(b) then + return false + end + if type(a) == 'table' then + for k, v in pairs(a) do + if not luax.deep_equal(v, b[k]) then + return false + end + end + for k, _ in pairs(b) do + if a[k] == nil then + return false + end + end + return true + end + return false +end + +--- Add the reverse lookup values to an existing table. +--- For example: +--- ``tbl_add_reverse_lookup { A = 1 } == { [1] = 'A', A = 1 }`` +--- +--- Note that this *modifies* the input. +---@param o table Table to add the reverse to +---@return table o +function luax.tbl_add_reverse_lookup(o) + local keys = luax.tbl_keys(o) + for _, k in ipairs(keys) do + local v = o[k] + if o[v] then + error( + string.format( + 'The reverse lookup found an existing value for %q while processing key %q', + tostring(v), + tostring(k) + ) + ) + end + o[v] = k + end + return o +end + +--- Index into a table (first argument) via string keys passed as subsequent arguments. +--- Return `nil` if the key does not exist. +--- +--- Examples: +---
+---  luax.tbl_get({ key = { nested_key = true }}, 'key', 'nested_key') == true
+---  luax.tbl_get({ key = {}}, 'key', 'nested_key') == nil
+--- 
+--- +---@param o table Table to index +---@param ... string Optional strings (0 or more, variadic) via which to index the table +--- +---@return any Nested value indexed by key (if it exists), else nil +function luax.tbl_get(o, ...) + local keys = { ... } + if #keys == 0 then + return + end + for i, k in ipairs(keys) do + if type(o[k]) ~= 'table' and next(keys, i) then + return nil + end + o = o[k] + if o == nil then + return + end + end + return o +end + +--- Extends a list-like table with the values of another list-like table. +--- +--- NOTE: This mutates dst! +--- +---@see |luax.tbl_extend()| +--- +---@param dst table List which will be modified and appended to +---@param src table List from which values will be inserted +---@param start number Start index on src. Defaults to 1 +---@param finish number Final index on src. Defaults to `#src` +---@return table dst +function luax.list_extend(dst, src, start, finish) + luax.validate({ + dst = { dst, 't' }, + src = { src, 't' }, + start = { start, 'n', true }, + finish = { finish, 'n', true }, + }) + for i = start or 1, finish or #src do + table.insert(dst, src[i]) + end + return dst +end + +--- Creates a copy of a list-like table such that any nested tables are +--- "unrolled" and appended to the result. +--- +---@see From https://github.com/premake/premake-core/blob/master/src/base/table.lua +--- +---@param t table List-like table +---@return table Flattened copy of the given list-like table +function luax.tbl_flatten(t) + local result = {} + local function _tbl_flatten(_t) + local n = #_t + for i = 1, n do + local v = _t[i] + if type(v) == 'table' then + _tbl_flatten(v) + elseif v then + table.insert(result, v) + end + end + end + _tbl_flatten(t) + return result +end + +--- Tests if a Lua table can be treated as an array. +--- +--- Empty table `{}` is assumed to be an array, unless it was created by +--- |luax.empty_dict()| or returned as a dict-like |API| or Vimscript result, +--- for example from |rpcrequest()| or |luax.fn|. +--- +---@param t table Table +---@return boolean `true` if array-like table, else `false` +function luax.tbl_islist(t) + if type(t) ~= 'table' then + return false + end + + local count = 0 + + for k, _ in pairs(t) do + if type(k) == 'number' then + count = count + 1 + else + return false + end + end + + if count > 0 then + return true + else + -- TODO(bfredl): in the future, we will always be inside nvim + -- then this check can be deleted. + if luax._empty_dict_mt == nil then + return nil + end + return getmetatable(t) ~= luax._empty_dict_mt + end +end + +--- Counts the number of non-nil values in table `t`. +--- +---
+--- luax.tbl_count({ a=1, b=2 }) => 2
+--- luax.tbl_count({ 1, 2 }) => 2
+--- 
+--- +---@see https://github.com/Tieske/Penlight/blob/master/lua/pl/tablex.lua +---@param t table Table +---@return number Number of non-nil values in table +function luax.tbl_count(t) + luax.validate({ t = { t, 't' } }) + + local count = 0 + for _ in pairs(t) do + count = count + 1 + end + return count +end + +--- Creates a copy of a table containing only elements from start to end (inclusive) +--- +---@param list table Table +---@param start number Start range of slice +---@param finish number End range of slice +---@return table Copy of table sliced from start to finish (inclusive) +function luax.list_slice(list, start, finish) + local new_list = {} + for i = start or 1, finish or #list do + new_list[#new_list + 1] = list[i] + end + return new_list +end + +--- Trim whitespace (Lua pattern "%s") from both sides of a string. +--- +---@see https://www.lua.org/pil/20.2.html +---@param s string String to trim +---@return string String with whitespace removed from its beginning and end +function luax.trim(s) + luax.validate({ s = { s, 's' } }) + return s:match('^%s*(.*%S)') or '' +end + +--- Escapes magic chars in |lua-patterns|. +--- +---@see https://github.com/rxi/lume +---@param s string String to escape +---@return string %-escaped pattern string +function luax.pesc(s) + luax.validate({ s = { s, 's' } }) + return s:gsub('[%(%)%.%%%+%-%*%?%[%]%^%$]', '%%%1') +end + +--- Tests if `s` starts with `prefix`. +--- +---@param s string String +---@param prefix string Prefix to match +---@return boolean `true` if `prefix` is a prefix of `s` +function luax.startswith(s, prefix) + luax.validate({ s = { s, 's' }, prefix = { prefix, 's' } }) + return s:sub(1, #prefix) == prefix +end + +--- Tests if `s` ends with `suffix`. +--- +---@param s string String +---@param suffix string Suffix to match +---@return boolean `true` if `suffix` is a suffix of `s` +function luax.endswith(s, suffix) + luax.validate({ s = { s, 's' }, suffix = { suffix, 's' } }) + return #suffix == 0 or s:sub(-#suffix) == suffix +end + +--- Validates a parameter specification (types and values). +--- +--- Usage example: +---
+---  function user.new(name, age, hobbies)
+---    luax.validate{
+---      name={name, 'string'},
+---      age={age, 'number'},
+---      hobbies={hobbies, 'table'},
+---    }
+---    ...
+---  end
+--- 
+--- +--- Examples with explicit argument values (can be run directly): +---
+---  luax.validate{arg1={{'foo'}, 'table'}, arg2={'foo', 'string'}}
+---     => NOP (success)
+---
+---  luax.validate{arg1={1, 'table'}}
+---     => error('arg1: expected table, got number')
+---
+---  luax.validate{arg1={3, function(a) return (a % 2) == 0 end, 'even number'}}
+---     => error('arg1: expected even number, got 3')
+--- 
+--- +--- If multiple types are valid they can be given as a list. +---
+---  luax.validate{arg1={{'foo'}, {'table', 'string'}}, arg2={'foo', {'table', 'string'}}}
+---     => NOP (success)
+---
+---  luax.validate{arg1={1, {'string', table'}}}
+---     => error('arg1: expected string|table, got number')
+---
+--- 
+--- +---@param opt table Names of parameters to validate. Each key is a parameter +--- name; each value is a tuple in one of these forms: +--- 1. (arg_value, type_name, optional) +--- - arg_value: argument value +--- - type_name: string|table type name, one of: ("table", "t", "string", +--- "s", "number", "n", "boolean", "b", "function", "f", "nil", +--- "thread", "userdata") or list of them. +--- - optional: (optional) boolean, if true, `nil` is valid +--- 2. (arg_value, fn, msg) +--- - arg_value: argument value +--- - fn: any function accepting one argument, returns true if and +--- only if the argument is valid. Can optionally return an additional +--- informative error message as the second returned value. +--- - msg: (optional) error string if validation fails +function luax.validate(opt) end -- luacheck: no unused + +do + local type_names = { + ['table'] = 'table', + t = 'table', + ['string'] = 'string', + s = 'string', + ['number'] = 'number', + n = 'number', + ['boolean'] = 'boolean', + b = 'boolean', + ['function'] = 'function', + f = 'function', + ['callable'] = 'callable', + c = 'callable', + ['nil'] = 'nil', + ['thread'] = 'thread', + ['userdata'] = 'userdata', + } + + local function _is_type(val, t) + return type(val) == t or (t == 'callable' and luax.is_callable(val)) + end + + ---@private + local function is_valid(opt) + if type(opt) ~= 'table' then + return false, string.format('opt: expected table, got %s', type(opt)) + end + + for param_name, spec in pairs(opt) do + if type(spec) ~= 'table' then + return false, string.format('opt[%s]: expected table, got %s', param_name, type(spec)) + end + + local val = spec[1] -- Argument value + local types = spec[2] -- Type name, or callable + local optional = (true == spec[3]) + + if type(types) == 'string' then + types = { types } + end + + if luax.is_callable(types) then + -- Check user-provided validation function + local valid, optional_message = types(val) + if not valid then + local error_message = + string.format('%s: expected %s, got %s', param_name, (spec[3] or '?'), tostring(val)) + if optional_message ~= nil then + error_message = error_message .. string.format('. Info: %s', optional_message) + end + + return false, error_message + end + elseif type(types) == 'table' then + local success = false + for i, t in ipairs(types) do + local t_name = type_names[t] + if not t_name then + return false, string.format('invalid type name: %s', t) + end + types[i] = t_name + + if (optional and val == nil) or _is_type(val, t_name) then + success = true + break + end + end + if not success then + return false, + string.format( + '%s: expected %s, got %s', + param_name, + table.concat(types, '|'), + type(val) + ) + end + else + return false, string.format('invalid type name: %s', tostring(types)) + end + end + + return true, nil + end + + function luax.validate(opt) + local ok, err_msg = is_valid(opt) + if not ok then + error(err_msg, 2) + end + end +end +--- Returns true if object `f` can be called as a function. +--- +---@param f any Any object +---@return boolean `true` if `f` is callable, else `false` +function luax.is_callable(f) + if type(f) == 'function' then + return true + end + local m = getmetatable(f) + if m == nil then + return false + end + return type(m.__call) == 'function' +end + +--- Creates a table whose members are automatically created when accessed, if they don't already +--- exist. +--- +--- They mimic defaultdict in python. +--- +--- If {create} is `nil`, this will create a defaulttable whose constructor function is +--- this function, effectively allowing to create nested tables on the fly: +--- +---
+--- local a = luax.defaulttable()
+--- a.b.c = 1
+--- 
+--- +---@param create function|nil The function called to create a missing value. +---@return table Empty table with metamethod +function luax.defaulttable(create) + create = create or luax.defaulttable + return setmetatable({}, { + __index = function(tbl, key) + rawset(tbl, key, create()) + return rawget(tbl, key) + end, + }) +end + +return luax diff --git a/engine/lua/cloe/system.lua b/engine/lua/cloe/system.lua new file mode 100644 index 000000000..0596d33c4 --- /dev/null +++ b/engine/lua/cloe/system.lua @@ -0,0 +1,108 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local api = require("cloe-engine") +local validate = require("cloe.typecheck").validate + +local system = {} + +--- Run a command with the default system shell and return the output. +--- +--- Discover the shell with: +--- +--- cloe.system('echo "System shell: $0" >&2') +--- +--- Note on stderr: +--- Only stdout is captured. The stderr output from the command +--- is sent straight to stderr of the calling program and not +--- discarded. +--- +--- Capture stderr with: +--- +--- cmd 2>&1 +--- +--- Discard stderr with: +--- +--- cmd 2>/dev/null +--- +--- @param cmd CommandSpec Command to run +--- @return string, number # Combined output, exit code +function system.exec(cmd) + -- FIXME: This is not a very nice API... + if type(cmd) == "string" then + cmd = { + command = cmd, + } + end + if cmd.log_output == nil then cmd.log_output = "on_error" end + if cmd.ignore_failure == nil then cmd.ignore_failure = true end + return api.exec(cmd) +end + +--- Return output from system call or nil on failure. +--- +--- @param cmd CommandSpec +--- @return string|nil +function system.exec_or_nil(cmd) + validate("cloe.report.exec_or_nil(CommandSpec)", cmd) + local out, ec = system.exec(cmd) + if ec ~= 0 then + return nil + end + return out +end + +--- Return system hostname. +--- +--- @return string|nil +function system.get_hostname() + -- FIXME(windows): Does `hostname` have `-f` argument in Windows? + return system.exec_or_nil("hostname -f") +end + +--- Return current username. +--- +--- In a Docker container this probably doesn't provide a lot of value. +--- +--- @return string|nil +function system.get_username() + return system.exec_or_nil("whoami") +end + +--- Return current date and time in RFC 3339 format. +--- +--- Example: 2006-08-14 02:34:56-06:00 +--- +--- @return string +function system.get_datetime() + return tostring(os.date("%Y-%m-%d %H:%M")) +end + +--- Return Git hash of HEAD for the given directory path. +--- +--- @param path string +--- @return string|nil +function system.get_git_hash(path) + validate("system.get_git_hash(string)", path) + return system.exec_or_nil({ + path = "git", + args = {"-C", path, "rev-parse", "HEAD"} + }) +end + +return system diff --git a/engine/lua/cloe/testing.lua b/engine/lua/cloe/testing.lua new file mode 100644 index 000000000..5b61c08da --- /dev/null +++ b/engine/lua/cloe/testing.lua @@ -0,0 +1,822 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local api = require("cloe-engine") +local types = require("cloe-engine.types") +local actions = require("cloe.actions") +local events = require("cloe.events") + +local LogLevel = types.LogLevel + +local luax = require("cloe.luax") +local validate = require("cloe.typecheck").validate +local inspect = require("inspect").inspect + +--- @class TestFixture +--- @field private _test table +--- @field private _id string +--- @field private _report table +--- @field private _coroutine thread +--- @field private _source string +--- @field private _sync Sync +--- @field private _asserts integer +--- @field private _failures integer +local TestFixture = {} + +--- @class SchedulerInterface +--- +--- Interface to scheduler to support dependency injection and +--- so we can avoid cyclic dependency to/from cloe.engine. +--- +--- @field log fun(level: string, fmt: string, ...: any) +--- @field schedule fun(trigger: Task) +--- @field execute_action fun(action: string|table) + +--- @enum TestStatus +local TestStatus = { + PENDING = "pending", --- Waiting to be scheduled + RUNNING = "running", --- Currently running + ABORTED = "aborted", --- Aborted because of error + STOPPED = "stopped", --- Stopped without explicit pass/fail + FAILED = "failed", --- Stopped with >=1 asserts failed + PASSED = "passed", --- Stopped with all asserts passed +} + +--- @enum ReportOutcome +local ReportOutcome = { + FAILURE = "fail", --- >=1 tests failed + SUCCESS = "pass", --- all tests passed +} + +--- Return a new test fixture for test +--- +--- @param test Test +--- @param scheduler? SchedulerInterface for dependency injection +--- @return TestFixture +--- @nodiscard +function TestFixture.new(test, scheduler) + scheduler = scheduler or {} + + local debinfo = debug.getinfo(test.run) + local source = string.format("%s:%s-%s", debinfo.short_src, debinfo.linedefined, debinfo.lastlinedefined) + + local report = api.state.report + if report["tests"] == nil then + report["tests"] = {} + end + if report.tests[test.id] then + error("test already scheduled with id: " .. test.id) + else + report.tests[test.id] = { + complete = false, + status = TestStatus.PENDING, + info = test.info, + activity = {}, + source = source, + } + end + + return setmetatable({ + _id = test.id, + _test = test, + _report = report.tests[test.id], + _coroutine = coroutine.create(test.run), + _source = source, + _stopped = false, + _asserts = 0, + _asserts_failed = 0, + _asserts_passed = 0, + _log = scheduler.log, + _schedule = scheduler.schedule, + _execute_action = scheduler.execute_action, + }, { + __index = TestFixture, + }) +end + +--- Schedule the test-case. +--- +--- @private +--- @return nil +function TestFixture:schedule_self() + self:_schedule({ + on = self._test.on, + group = self._test.id, + pin = false, + desc = self._test.desc, + enable = true, + source = self._source, + run = function(sync) + self._sync = sync + self:_log("info", "Running test: %s", self._id) + self:_resume(self, sync) + end, + }) +end + +--- Log a message. +--- +--- @private +--- @param level LogLevel +--- @param message string +--- @param ... any +--- @return nil +function TestFixture:_log(level, message, ...) + require("cloe.engine").log(level, message, ...) +end + +--- Schedule resumption of test. +--- +--- @private +--- @param task Task +function TestFixture:_schedule(task) + require("cloe.engine").schedule(task) +end + +--- Execute an action immediately. +--- +--- @private +--- @param action string|table +function TestFixture:_execute_action(action) + require("cloe.engine").execute_action(action) +end + +--- Resume execution of the test after an interuption. +--- +--- This is called at the beginning of the test, and any time it +--- hands control back to the engine to do other work. +--- +--- @private +--- @param ... any +--- @return nil +function TestFixture:_resume(...) + self:_log(LogLevel.DEBUG, "Resuming test %s", self._id) + self:_set_status(TestStatus.RUNNING) + local ok, result = coroutine.resume(self._coroutine, ...) + if not ok then + self:_set_status(TestStatus.ABORTED) + error(string.format("Error with test %s: %s", self._id, result)) + elseif result then + local result_type = type(result) + if result_type == "table" then + -- From self:wait*() methods + self:_schedule(result) + self:_set_status(TestStatus.PENDING) + elseif result_type == "function" then + -- From self:stop() methods + result() + self:_finish() + else + self:_set_status(TestStatus.ABORTED) + error("unknown test yield result: " .. inspect(result)) + end + else + -- From end-of-test-case + self:_finish() + end +end + +--- @private +function TestFixture:_finish() + -- After the test completes, update the report + self._report["asserts"] = { + total = self._asserts, + failed = self._asserts_failed, + passed = self._asserts_passed, + } + if self._asserts_failed > 0 then + self:_set_status(TestStatus.FAILED) + else + self:_set_status(TestStatus.PASSED) + end + self._report.complete = true + self:_terminate() +end + +--- @private +--- @param status TestStatus +function TestFixture:_set_status(status) + self:_log(LogLevel.DEBUG, "[%s] Status -> %s", self._id, status) + self._report.status = status +end + +--- @private +function TestFixture:_terminate() + local report = api.state.report + local tests = 0 + local tests_failed = 0 + for _, test in pairs(report["tests"]) do + if not test.complete then + -- Not all tests complete, let the next fixture do the job + return + end + + tests = tests + 1 + if test.status ~= TestStatus.PASSED then + tests_failed = tests_failed + 1 + end + end + if tests_failed ~= 0 then + report.outcome = ReportOutcome.FAILURE + else + report.outcome = ReportOutcome.SUCCESS + end + + local term = self._test.terminate == nil and true or self._test.terminate + if type(term) == "function" then + term = term(self, self._sync) + end + if term then + self:_log(LogLevel.INFO, "Terminating simulation (disable with terminate=false)...") + if report.outcome == ReportOutcome.FAILURE then + self:_execute_action(actions.fail()) + elseif report.outcome == ReportOutcome.SUCCESS then + self:_execute_action(actions.succeed()) + else + self:_execute_action(actions.stop()) + end + end +end + +--- Add some data to the report. +--- +--- Will also log the message as debug. +--- +--- @param data table +--- @param quiet? boolean +--- @param level? string +--- @return nil +function TestFixture:report_data(data, quiet, level) + data = luax.tbl_extend("error", { time = tostring(self._sync:time()) }, data) + if not quiet then + self:_log(level or LogLevel.DEBUG, "[%s] Report: %s", self._id, inspect(data, { indent = " ", newline = "" })) + end + table.insert(self._report.activity, data) +end + +--- Add a message to the field to the report and log it with the given severity. +--- +--- @param field string Field to assign message to. +--- @param level string Severity to log the message. +--- @param fmt string Format string. +--- @param ... any Arguments to format string. +--- @return nil +function TestFixture:report_with(field, level, fmt, ...) + validate("TestFixture:report_with(string, string, string, [?any]...)", self, field, level, fmt, ...) + local msg = string.format(fmt, ...) + self:_log(level, "[%s] Report %s: %s", self._id, field, msg) + self:report_data({ [field] = msg }, true) +end + +--- Add a message to the report and log it with the given severity. +--- +--- @param level string Severity to log the message. +--- @param fmt string Format string. +--- @param ... any Arguments to format string. +--- @return nil +function TestFixture:report_message(level, fmt, ...) + validate("TestFixture:report_message(string, string, [?any]...)", self, level, fmt, ...) + self:report_with("message", level, fmt, ...) +end + +--- Log a message to report and console in debug severity. +--- +--- @param fmt string +--- @param ... any +--- @return nil +function TestFixture:debugf(fmt, ...) + self:report_message(LogLevel.DEBUG, fmt, ...) +end + +--- Log a message to report and console in info severity. +--- +--- @param fmt string +--- @param ... any +--- @return nil +function TestFixture:printf(fmt, ...) + self:report_message(LogLevel.INFO, fmt, ...) +end + +--- Log a message to report and console in warn severity. +--- +--- @param fmt string +--- @param ... any +--- @return nil +function TestFixture:warnf(fmt, ...) + self:report_message(LogLevel.WARN, fmt, ...) +end + +--- Log a message to report and console in error severity. +--- +--- Note: this does not have an effect on the test results. +--- +--- @param fmt string +--- @param ... any +--- @return nil +function TestFixture:errorf(fmt, ...) + self:report_message(LogLevel.ERROR, fmt, ...) +end + +--- Terminate the execution of the test-case, but not the simulation. +--- +--- @param fmt? string +--- @param ... any +--- @return nil +function TestFixture:stop(fmt, ...) + validate("TestFixture:stop([string], [?any]...)", self, fmt, ...) + if fmt then + self:printf(fmt, ...) + end + coroutine.yield(function() + if self._report.status == TestStatus.PENDING then + self:_set_status(TestStatus.STOPPED) + end + if coroutine.close then + coroutine.close(self._coroutine) + else + self._coroutine = nil + end + end) +end + +--- Fail the test-case and stop the simulation. +--- +--- Note: It is best practice to use expect and assert methods and allow the +--- test-case fixture to determine failure/success itself. +--- +--- @param fmt string optional message +--- @param ... any +--- @return nil +function TestFixture:fail(fmt, ...) + validate("TestFixture:fail([string], [?any]...)", self, fmt, ...) + if fmt then + self:errorf(fmt, ...) + end + self:do_action(actions.fail()) + self:stop() +end + +--- Succeed the test-case and stop the simulation. +--- +--- Note: It is best practice to use expect and assert methods and allow the +--- test-case fixture to determine failure/success itself. +--- +--- @param fmt? string optional message +--- @param ... any +--- @return nil +function TestFixture:succeed(fmt, ...) + validate("TestFixture:succeed([string], [?any]...)", self, fmt, ...) + if fmt then + self:printf(fmt, ...) + end + self:do_action(actions.succeed()) + self:stop() +end + +--- Wait simulated time given in duration units, e.g. "1.5s". +--- +--- This will yield execution of the test-case back to the simulation +--- until the duration has elapsed. +--- +--- @param duration string +--- @return nil +function TestFixture:wait_duration(duration) + validate("TestFixture:wait_duration(string)", self, duration) + self:debugf("wait for duration: %s", duration) + coroutine.yield({ + on = "next=" .. types.Duration.new(duration):s(), + group = self._id, + run = function(sync) + return self:_resume() + end, + }) +end + +--- Wait until the function supplied returns true, then resume. +--- +--- This will yield execution of the test-case back to the simulation +--- until the function, which is run once every cycle, returns true. +--- +--- @param condition fun(sync: Sync):boolean +--- @param timeout? Duration|string +--- @return nil +function TestFixture:wait_until(condition, timeout) + validate("TestFixture:wait_until(function, [string|userdata|nil])", self, condition, timeout) + if type(timeout) == "string" then + timeout = types.Duration.new(timeout) + end + if timeout then + timeout = self._sync:time() + timeout + self:debugf("wait until condition with timeout %s: %s", timeout, condition) + else + self:debugf("wait until condition: %s", condition) + end + coroutine.yield({ + on = events.loop(), + group = self._id, + pin = true, + run = function(sync) + if condition(sync) then + self:_resume(true) + return false + elseif timeout and sync:time() > timeout then + self:warnf("condition timed out after %s", timeout) + self:_resume(false) + return false + end + end, + }) +end + +--- @class WaitSpec +--- @field condition? function(Sync):boolean +--- @field timeout? Duration|string +--- @field fail_on_timeout? boolean + +--- @param spec WaitSpec +function TestFixture:wait_for(spec) + validate("TestFixture:wait_for(table)", spec) + if not spec.condition and not spec.timeout then + error("TestFixture:wait_for(): require one of condition or timeout to be set") + end + error("not implemented") +end + +function TestFixture:do_action(action) + validate("TestFixture:do_action(string|table)", self, action) + self:report_with("action", "debug", "%s", action) + self:_execute_action(action) +end + +--- @class Operator +--- @field fn fun(left: any, right: any):boolean +--- @field repr string + +--- Expect an operation with op(left, right) == true. +--- +--- @private +--- @param op Operator operator object with comparison function and string representation +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional message format string +--- @param ... any arguments to string.format +--- @return boolean # result of expression +function TestFixture:_expect_op(op, left, right, fmt, ...) + validate("TestFixture:_expect_op(table, any, any, [string], [?any]...)", self, op, left, right, fmt, ...) + self._asserts = self._asserts + 1 + local msg = nil + if fmt then + msg = string.format(fmt, ...) + end + local report = { + assert = string.format("%s %s %s", left, op.repr, right), + left = inspect(left, { newline = " ", indent = "" }), + right = inspect(right, { newline = " ", indent = "" }), + value = op.fn(left, right), + message = msg, + } + self:report_data(report, true) + if report.value then + self._asserts_passed = self._asserts_passed + 1 + self:_log(LogLevel.INFO, "[%s] Check %s: %s (=%s)", self._id, msg or "ok", report.assert, report.value) + else + self._asserts_failed = self._asserts_failed + 1 + self:_log(LogLevel.ERROR, "[%s] !! Check %s: %s (=%s)", self._id, msg or "failed", report.assert, report.value) + end + return report.value +end + +--- Expect that the first argument is truthy. +--- +--- On failure, execution continues, but the test-case is marked as failed. +--- +--- The message should describe the expectation: +--- +--- z:expect(var == 1, "var should == 1, is %s", var) +--- +--- You should check if a more specific assertion is available first though, +--- as these provide better messages. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param value any expression or value that should be truthy +--- @param fmt? string human-readable expectation of result +--- @param ... any arguments to format string +--- @return any value # input value / expression result +function TestFixture:expect(value, fmt, ...) + return self:_expect_op({ + fn = function(a) + return a + end, + repr = "is", + }, value, "truthy", fmt, ...) +end + +--- Assert that the first argument is truthy. +--- +--- On failure, execution is stopped and the test-case is marked as failed. +--- +--- The message should describe the expectation: +--- +--- z:assert(var == 1, "var should == 1, is %s", var) +--- +--- You should check if a more specific assertion is available first though, +--- as these provide better messages. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param value any expression or value that should be truthy +--- @param fmt? string human-readable expectation of result +--- @param ... any arguments to format string +--- @return any value # input value / expression result +function TestFixture:assert(value, fmt, ...) + if not self:expect(value, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end + return value +end + +local Operator = { + eq = { fn = function(a, b) return a == b end, repr = "==", }, --- Equal to + ne = { fn = function(a, b) return a ~= b end, repr = "~=", }, --- Not equal to + lt = { fn = function(a, b) return a < b end, repr = "<", }, --- Less than + le = { fn = function(a, b) return a <= b end, repr = "<=", }, --- Less than or equal to + gt = { fn = function(a, b) return a > b end, repr = ">", }, --- Greater than + ge = { fn = function(a, b) return a >= b end, repr = ">=", }, --- Greater than or equal to +} + +--- Expect that `left == right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_eq(left, right, fmt, ...) + return self:_expect_op(Operator.eq, left, right, fmt, ...) +end + +--- Expect that `left ~= right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_ne(left, right, fmt, ...) + return self:_expect_op(Operator.ne, left, right, fmt, ...) +end + +--- Expect that `left < right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_lt(left, right, fmt, ...) + return self:_expect_op(Operator.lt, left, right, fmt, ...) +end + +--- Expect that `left <= right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_le(left, right, fmt, ...) + return self:_expect_op(Operator.le, left, right, fmt, ...) +end + +--- Expect that `left > right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_gt(left, right, fmt, ...) + return self:_expect_op(Operator.gt, left, right, fmt, ...) +end + +--- Expect that `left >= right` or mark test-case as failed. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison +function TestFixture:expect_ge(left, right, fmt, ...) + return self:_expect_op(Operator.ge, left, right, fmt, ...) +end + +--- Assert that `left == right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_eq(left, right, fmt, ...) + if not self:_expect_op(Operator.eq, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Assert that `left ~= right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_ne(left, right, fmt, ...) + if not self:_expect_op(Operator.ne, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Assert that `left < right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_lt(left, right, fmt, ...) + if not self:_expect_op(Operator.lt, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Assert that `left <= right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_le(left, right, fmt, ...) + if not self:_expect_op(Operator.le, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Assert that `left > right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_gt(left, right, fmt, ...) + if not self:_expect_op(Operator.gt, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Assert that `left >= right` or fail simulation. +--- +--- See: +--- - [string.format](https://www.lua.org/manual/5.3/manual.html#pdf-string.format) +--- for help on formatting +--- +--- @param left any left-hand operand +--- @param right any right-hand operand +--- @param fmt? string optional human-readable string describing expectation +--- @param ... any optional arguments to string.format +--- @return any # result of comparison, if true +function TestFixture:assert_ge(left, right, fmt, ...) + if not self:_expect_op(Operator.ge, left, right, fmt, ...) then + self:fail("[%s] test assertion failed", self._id) + end +end + +--- Return number of leading tabs in string. +--- +--- @param str string +--- @return number +--- @nodiscard +local function count_leading_tabs(str) + local count = 0 + for i = 1, #str do + local char = string.sub(str, i, i) + if char == "\t" then + count = count + 1 + else + break + end + end + return count +end + +--- Start a description block based on the Lust framework. +--- +--- The Lust framework provides Behavior-Driven-Development (BDD) style +--- test tooling. See their website for more information: +--- +--- https://github.com/bjornbytes/lust +--- +--- Warning: In unfortunate circumstances, using this method may (in its +--- current implementation) result in error messages and/or other output +--- from Lua that uses the print() statement being suppressed. +--- +--- @deprecated EXPERIMENTAL +--- @param name string Description of subject +--- @param fn fun() Function to scope test execution +--- @return nil +function TestFixture:describe(name, fn) + local lust = require("lust") + lust.nocolor() + + local lust_describe_activity = { name = "", evaluation = {} } + + -- Lust uses print(), so we hijack the function temporarily to capture + -- its output. + -- + -- NOTE: This also means that if there is some kind of error within + -- such a describe block, we may not re-install the original print + -- function and all further output may be suppressed! + local oldprint = _G.print + _ENV.print = function(msg) + local tab_count = count_leading_tabs(msg) + msg = luax.trim(msg) -- remove leading tab + if tab_count == 0 then + lust_describe_activity["name"] = msg + elseif tab_count > 0 then + table.insert(lust_describe_activity.evaluation, msg) + end + end + lust.describe(name, fn) + _ENV.print = oldprint + + self:report_data(lust_describe_activity) +end + +return { + TestFixture = TestFixture, +} diff --git a/engine/lua/cloe/typecheck.lua b/engine/lua/cloe/typecheck.lua new file mode 100644 index 000000000..e6f8f2e5a --- /dev/null +++ b/engine/lua/cloe/typecheck.lua @@ -0,0 +1,51 @@ +-- +-- Copyright 2023 Robert Bosch GmbH +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- + +local m = {} + +local skip_typechecks = false + +function m.disable() + skip_typechecks = true +end + +function m.validate(format, ...) + if skip_typechecks then + return + end + local fn = require("typecheck").argscheck(format, function() end) + fn(...) +end + +--- Validate the shape (from tableshape) of a table or type. +--- +--- @param signature string function signature for error message +--- @param shape any shape validator +--- @param value any value to validate +--- @return nil # raises an error (level 3) if invalid +function m.validate_shape(signature, shape, value) + if skip_typechecks then + return + end + local ok, msg = shape:check_value(value) + if not ok then + error(signature .. ": " .. msg, 3) + end +end + +return m diff --git a/engine/lua/inspect.lua b/engine/lua/inspect.lua new file mode 100644 index 000000000..c232f6959 --- /dev/null +++ b/engine/lua/inspect.lua @@ -0,0 +1,379 @@ +local inspect = { + _VERSION = 'inspect.lua 3.1.0', + _URL = 'http://github.com/kikito/inspect.lua', + _DESCRIPTION = 'human-readable representations of tables', + _LICENSE = [[ + MIT LICENSE + + Copyright (c) 2013 Enrique García Cota + + Permission is hereby granted, free of charge, to any person obtaining a + copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ]], +} + +inspect.KEY = setmetatable({}, { + __tostring = function() + return 'inspect.KEY' + end, +}) +inspect.METATABLE = setmetatable({}, { + __tostring = function() + return 'inspect.METATABLE' + end, +}) + +local tostring = tostring +local rep = string.rep +local match = string.match +local char = string.char +local gsub = string.gsub +local fmt = string.format + +local function rawpairs(t) + return next, t, nil +end + +-- Apostrophizes the string if it has quotes, but not aphostrophes +-- Otherwise, it returns a regular quoted string +local function smartQuote(str) + if match(str, '"') and not match(str, "'") then + return "'" .. str .. "'" + end + return '"' .. gsub(str, '"', '\\"') .. '"' +end + +-- \a => '\\a', \0 => '\\0', 31 => '\31' +local shortControlCharEscapes = { + ['\a'] = '\\a', + ['\b'] = '\\b', + ['\f'] = '\\f', + ['\n'] = '\\n', + ['\r'] = '\\r', + ['\t'] = '\\t', + ['\v'] = '\\v', + ['\127'] = '\\127', +} +local longControlCharEscapes = { ['\127'] = '\127' } +for i = 0, 31 do + local ch = char(i) + if not shortControlCharEscapes[ch] then + shortControlCharEscapes[ch] = '\\' .. i + longControlCharEscapes[ch] = fmt('\\%03d', i) + end +end + +local function escape(str) + return ( + gsub( + gsub(gsub(str, '\\', '\\\\'), '(%c)%f[0-9]', longControlCharEscapes), + '%c', + shortControlCharEscapes + ) + ) +end + +-- List of lua keywords +local luaKeywords = { + ['and'] = true, + ['break'] = true, + ['do'] = true, + ['else'] = true, + ['elseif'] = true, + ['end'] = true, + ['false'] = true, + ['for'] = true, + ['function'] = true, + ['goto'] = true, + ['if'] = true, + ['in'] = true, + ['local'] = true, + ['nil'] = true, + ['not'] = true, + ['or'] = true, + ['repeat'] = true, + ['return'] = true, + ['then'] = true, + ['true'] = true, + ['until'] = true, + ['while'] = true, +} + +local function isIdentifier(str) + return type(str) == 'string' + -- identifier must start with a letter and underscore, and be followed by letters, numbers, and underscores + and not not str:match('^[_%a][_%a%d]*$') + -- lua keywords are not valid identifiers + and not luaKeywords[str] +end + +local flr = math.floor +local function isSequenceKey(k, sequenceLength) + return type(k) == 'number' and flr(k) == k and 1 <= k and k <= sequenceLength +end + +local defaultTypeOrders = { + ['number'] = 1, + ['boolean'] = 2, + ['string'] = 3, + ['table'] = 4, + ['function'] = 5, + ['userdata'] = 6, + ['thread'] = 7, +} + +local function sortKeys(a, b) + local ta, tb = type(a), type(b) + + -- strings and numbers are sorted numerically/alphabetically + if ta == tb and (ta == 'string' or ta == 'number') then + return a < b + end + + local dta = defaultTypeOrders[ta] or 100 + local dtb = defaultTypeOrders[tb] or 100 + -- Two default types are compared according to the defaultTypeOrders table + + -- custom types are sorted out alphabetically + return dta == dtb and ta < tb or dta < dtb +end + +local function getKeys(t) + local seqLen = 1 + while rawget(t, seqLen) ~= nil do + seqLen = seqLen + 1 + end + seqLen = seqLen - 1 + + local keys, keysLen = {}, 0 + for k in rawpairs(t) do + if not isSequenceKey(k, seqLen) then + keysLen = keysLen + 1 + keys[keysLen] = k + end + end + table.sort(keys, sortKeys) + return keys, keysLen, seqLen +end + +local function countCycles(x, cycles) + if type(x) == 'table' then + if cycles[x] then + cycles[x] = cycles[x] + 1 + else + cycles[x] = 1 + for k, v in rawpairs(x) do + countCycles(k, cycles) + countCycles(v, cycles) + end + countCycles(getmetatable(x), cycles) + end + end +end + +local function makePath(path, a, b) + local newPath = {} + local len = #path + for i = 1, len do + newPath[i] = path[i] + end + + newPath[len + 1] = a + newPath[len + 2] = b + + return newPath +end + +local function processRecursive(process, item, path, visited) + if item == nil then + return nil + end + if visited[item] then + return visited[item] + end + + local processed = process(item, path) + if type(processed) == 'table' then + local processedCopy = {} + visited[item] = processedCopy + local processedKey + + for k, v in rawpairs(processed) do + processedKey = processRecursive(process, k, makePath(path, k, inspect.KEY), visited) + if processedKey ~= nil then + processedCopy[processedKey] = + processRecursive(process, v, makePath(path, processedKey), visited) + end + end + + local mt = + processRecursive(process, getmetatable(processed), makePath(path, inspect.METATABLE), visited) + if type(mt) ~= 'table' then + mt = nil + end + setmetatable(processedCopy, mt) + processed = processedCopy + end + return processed +end + +local function puts(buf, str) + buf.n = buf.n + 1 + buf[buf.n] = str +end + +local Inspector = {} + +local Inspector_mt = { __index = Inspector } + +local function tabify(inspector) + puts(inspector.buf, inspector.newline .. rep(inspector.indent, inspector.level)) +end + +function Inspector:getId(v) + local id = self.ids[v] + local ids = self.ids + if not id then + local tv = type(v) + id = (ids[tv] or 0) + 1 + ids[v], ids[tv] = id, id + end + return tostring(id) +end + +function Inspector:putValue(v) + local buf = self.buf + local tv = type(v) + if tv == 'string' then + puts(buf, smartQuote(escape(v))) + elseif + tv == 'number' + or tv == 'boolean' + or tv == 'nil' + or tv == 'cdata' + or tv == 'ctype' + or (vim and v == vim.NIL) + then + puts(buf, tostring(v)) + elseif tv == 'table' and not self.ids[v] then + local t = v + + if t == inspect.KEY or t == inspect.METATABLE then + puts(buf, tostring(t)) + elseif self.level >= self.depth then + puts(buf, '{...}') + else + if self.cycles[t] > 1 then + puts(buf, fmt('<%d>', self:getId(t))) + end + + local keys, keysLen, seqLen = getKeys(t) + local mt = getmetatable(t) + + if vim and seqLen == 0 and keysLen == 0 and mt == vim._empty_dict_mt then + puts(buf, tostring(t)) + return + end + + puts(buf, '{') + self.level = self.level + 1 + + for i = 1, seqLen + keysLen do + if i > 1 then + puts(buf, ',') + end + if i <= seqLen then + puts(buf, ' ') + self:putValue(t[i]) + else + local k = keys[i - seqLen] + tabify(self) + if isIdentifier(k) then + puts(buf, k) + else + puts(buf, '[') + self:putValue(k) + puts(buf, ']') + end + puts(buf, ' = ') + self:putValue(t[k]) + end + end + + if type(mt) == 'table' then + if seqLen + keysLen > 0 then + puts(buf, ',') + end + tabify(self) + puts(buf, ' = ') + self:putValue(mt) + end + + self.level = self.level - 1 + + if keysLen > 0 or type(mt) == 'table' then + tabify(self) + elseif seqLen > 0 then + puts(buf, ' ') + end + + puts(buf, '}') + end + else + puts(buf, fmt('<%s %d>', tv, self:getId(v))) + end +end + +function inspect.inspect(root, options) + options = options or {} + + local depth = options.depth or math.huge + local newline = options.newline or '\n' + local indent = options.indent or ' ' + local process = options.process + + if process then + root = processRecursive(process, root, {}, {}) + end + + local cycles = {} + countCycles(root, cycles) + + local inspector = setmetatable({ + buf = { n = 0 }, + ids = {}, + cycles = cycles, + depth = depth, + level = 0, + newline = newline, + indent = indent, + }, Inspector_mt) + + inspector:putValue(root) + + return table.concat(inspector.buf) +end + +setmetatable(inspect, { + __call = function(_, root, options) + return inspect.inspect(root, options) + end, +}) + +return inspect diff --git a/engine/lua/json.lua b/engine/lua/json.lua new file mode 100644 index 000000000..c98b38d72 --- /dev/null +++ b/engine/lua/json.lua @@ -0,0 +1,1869 @@ +-- -*- coding: utf-8 -*- +-- +-- Simple JSON encoding and decoding in pure Lua. +-- +-- Copyright 2010-2017 Jeffrey Friedl +-- http://regex.info/blog/ +-- Latest version: http://regex.info/blog/lua/json +-- +-- This code is released under a Creative Commons CC-BY "Attribution" License: +-- http://creativecommons.org/licenses/by/3.0/deed.en_US +-- +-- It can be used for any purpose so long as: +-- 1) the copyright notice above is maintained +-- 2) the web-page links above are maintained +-- 3) the 'AUTHOR_NOTE' string below is maintained +-- +local VERSION = '20211016.28' -- version history at end of file +local AUTHOR_NOTE = "-[ JSON.lua package by Jeffrey Friedl (http://regex.info/blog/lua/json) version 20211016.28 ]-" + +-- +-- The 'AUTHOR_NOTE' variable exists so that information about the source +-- of the package is maintained even in compiled versions. It's also +-- included in OBJDEF below mostly to quiet warnings about unused variables. +-- +local OBJDEF = { + VERSION = VERSION, + AUTHOR_NOTE = AUTHOR_NOTE, +} + + +-- +-- Simple JSON encoding and decoding in pure Lua. +-- JSON definition: http://www.json.org/ +-- +-- +-- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines +-- +-- local lua_value = JSON:decode(raw_json_text) +-- +-- local raw_json_text = JSON:encode(lua_table_or_value) +-- local pretty_json_text = JSON:encode_pretty(lua_table_or_value) -- "pretty printed" version for human readability +-- +-- +-- +-- DECODING (from a JSON string to a Lua table) +-- +-- +-- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines +-- +-- local lua_value = JSON:decode(raw_json_text) +-- +-- If the JSON text is for an object or an array, e.g. +-- { "what": "books", "count": 3 } +-- or +-- [ "Larry", "Curly", "Moe" ] +-- +-- the result is a Lua table, e.g. +-- { what = "books", count = 3 } +-- or +-- { "Larry", "Curly", "Moe" } +-- +-- +-- The encode and decode routines accept an optional second argument, +-- "etc", which is not used during encoding or decoding, but upon error +-- is passed along to error handlers. It can be of any type (including nil). +-- +-- +-- +-- ERROR HANDLING DURING DECODE +-- +-- With most errors during decoding, this code calls +-- +-- JSON:onDecodeError(message, text, location, etc) +-- +-- with a message about the error, and if known, the JSON text being +-- parsed and the byte count where the problem was discovered. You can +-- replace the default JSON:onDecodeError() with your own function. +-- +-- The default onDecodeError() merely augments the message with data +-- about the text and the location (and, an 'etc' argument had been +-- provided to decode(), its value is tacked onto the message as well), +-- and then calls JSON.assert(), which itself defaults to Lua's built-in +-- assert(), and can also be overridden. +-- +-- For example, in an Adobe Lightroom plugin, you might use something like +-- +-- function JSON:onDecodeError(message, text, location, etc) +-- LrErrors.throwUserError("Internal Error: invalid JSON data") +-- end +-- +-- or even just +-- +-- function JSON.assert(message) +-- LrErrors.throwUserError("Internal Error: " .. message) +-- end +-- +-- If JSON:decode() is passed a nil, this is called instead: +-- +-- JSON:onDecodeOfNilError(message, nil, nil, etc) +-- +-- and if JSON:decode() is passed HTML instead of JSON, this is called: +-- +-- JSON:onDecodeOfHTMLError(message, text, nil, etc) +-- +-- The use of the 'etc' argument allows stronger coordination between +-- decoding and error reporting, especially when you provide your own +-- error-handling routines. Continuing with the the Adobe Lightroom +-- plugin example: +-- +-- function JSON:onDecodeError(message, text, location, etc) +-- local note = "Internal Error: invalid JSON data" +-- if type(etc) = 'table' and etc.photo then +-- note = note .. " while processing for " .. etc.photo:getFormattedMetadata('fileName') +-- end +-- LrErrors.throwUserError(note) +-- end +-- +-- : +-- : +-- +-- for i, photo in ipairs(photosToProcess) do +-- : +-- : +-- local data = JSON:decode(someJsonText, { photo = photo }) +-- : +-- : +-- end +-- +-- +-- +-- If the JSON text passed to decode() has trailing garbage (e.g. as with the JSON "[123]xyzzy"), +-- the method +-- +-- JSON:onTrailingGarbage(json_text, location, parsed_value, etc) +-- +-- is invoked, where: +-- +-- 'json_text' is the original JSON text being parsed, +-- 'location' is the count of bytes into 'json_text' where the garbage starts (6 in the example), +-- 'parsed_value' is the Lua result of what was successfully parsed ({123} in the example), +-- 'etc' is as above. +-- +-- If JSON:onTrailingGarbage() does not abort, it should return the value decode() should return, +-- or nil + an error message. +-- +-- local new_value, error_message = JSON:onTrailingGarbage() +-- +-- The default JSON:onTrailingGarbage() simply invokes JSON:onDecodeError("trailing garbage"...), +-- but you can have this package ignore trailing garbage via +-- +-- function JSON:onTrailingGarbage(json_text, location, parsed_value, etc) +-- return parsed_value +-- end +-- +-- +-- DECODING AND STRICT TYPES +-- +-- Because both JSON objects and JSON arrays are converted to Lua tables, +-- it's not normally possible to tell which original JSON type a +-- particular Lua table was derived from, or guarantee decode-encode +-- round-trip equivalency. +-- +-- However, if you enable strictTypes, e.g. +-- +-- JSON = assert(loadfile "JSON.lua")() --load the routines +-- JSON.strictTypes = true +-- +-- then the Lua table resulting from the decoding of a JSON object or +-- JSON array is marked via Lua metatable, so that when re-encoded with +-- JSON:encode() it ends up as the appropriate JSON type. +-- +-- (This is not the default because other routines may not work well with +-- tables that have a metatable set, for example, Lightroom API calls.) +-- +-- +-- DECODING AND STRICT PARSING +-- +-- If strictParsing is true in your JSON object, or if you set strictParsing as a decode option, +-- some kinds of technically-invalid JSON that would normally be accepted are rejected with an error. +-- +-- For example, passing in an empty string +-- +-- JSON:decode("") +-- +-- normally succeeds with a return value of nil, but +-- +-- JSON:decode("", nil, { strictParsing = true }) +-- +-- results in an error being raised (onDecodeError is called). +-- +-- JSON.strictParsing = true +-- JSON:decode("") +-- +-- achieves the same thing. +-- +-- +-- +-- ENCODING (from a lua table to a JSON string) +-- +-- JSON = assert(loadfile "JSON.lua")() -- one-time load of the routines +-- +-- local raw_json_text = JSON:encode(lua_table_or_value) +-- local pretty_json_text = JSON:encode_pretty(lua_table_or_value) -- "pretty printed" version for human readability +-- local custom_pretty = JSON:encode(lua_table_or_value, etc, { pretty = true, indent = "| ", align_keys = false }) +-- +-- On error during encoding, this code calls: +-- +-- JSON:onEncodeError(message, etc) +-- +-- which you can override in your local JSON object. Also see "HANDLING UNSUPPORTED VALUE TYPES" below. +-- +-- The 'etc' in the error call is the second argument to encode() and encode_pretty(), or nil if it wasn't provided. +-- +-- +-- +-- +-- ENCODING OPTIONS +-- +-- An optional third argument, a table of options, can be provided to encode(). +-- +-- encode_options = { +-- -- options for making "pretty" human-readable JSON (see "PRETTY-PRINTING" below) +-- pretty = true, -- turn pretty formatting on +-- indent = " ", -- use this indent for each level of an array/object +-- align_keys = false, -- if true, align the keys in a way that sounds like it should be nice, but is actually ugly +-- array_newline = false, -- if true, array elements become one to a line rather than inline +-- +-- -- other output-related options +-- null = "\0", -- see "ENCODING JSON NULL VALUES" below +-- stringsAreUtf8 = false, -- see "HANDLING UNICODE LINE AND PARAGRAPH SEPARATORS FOR JAVA" below +-- } +-- +-- json_string = JSON:encode(mytable, etc, encode_options) +-- +-- +-- +-- For reference, the defaults are: +-- +-- pretty = false +-- null = nil, +-- stringsAreUtf8 = false, +-- +-- +-- +-- PRETTY-PRINTING +-- +-- Enabling the 'pretty' encode option helps generate human-readable JSON. +-- +-- pretty = JSON:encode(val, etc, { +-- pretty = true, +-- indent = " ", +-- align_keys = false, +-- }) +-- +-- encode_pretty() is also provided: it's identical to encode() except +-- that encode_pretty() provides a default options table if none given in the call: +-- +-- { pretty = true, indent = " ", align_keys = false, array_newline = false } +-- +-- For example, if +-- +-- JSON:encode(data) +-- +-- produces: +-- +-- {"city":"Kyoto","climate":{"avg_temp":16,"humidity":"high","snowfall":"minimal"},"country":"Japan","wards":11} +-- +-- then +-- +-- JSON:encode_pretty(data) +-- +-- produces: +-- +-- { +-- "city": "Kyoto", +-- "climate": { +-- "avg_temp": 16, +-- "humidity": "high", +-- "snowfall": "minimal" +-- }, +-- "country": "Japan", +-- "wards": 11 +-- } +-- +-- The following lines all return identical strings: +-- JSON:encode_pretty(data) +-- JSON:encode_pretty(data, nil, { pretty = true, indent = " ", align_keys = false, array_newline = false}) +-- JSON:encode_pretty(data, nil, { pretty = true, indent = " " }) +-- JSON:encode (data, nil, { pretty = true, indent = " " }) +-- +-- An example of setting your own indent string: +-- +-- JSON:encode_pretty(data, nil, { pretty = true, indent = "| " }) +-- +-- produces: +-- +-- { +-- | "city": "Kyoto", +-- | "climate": { +-- | | "avg_temp": 16, +-- | | "humidity": "high", +-- | | "snowfall": "minimal" +-- | }, +-- | "country": "Japan", +-- | "wards": 11 +-- } +-- +-- An example of setting align_keys to true: +-- +-- JSON:encode_pretty(data, nil, { pretty = true, indent = " ", align_keys = true }) +-- +-- produces: +-- +-- { +-- "city": "Kyoto", +-- "climate": { +-- "avg_temp": 16, +-- "humidity": "high", +-- "snowfall": "minimal" +-- }, +-- "country": "Japan", +-- "wards": 11 +-- } +-- +-- which I must admit is kinda ugly, sorry. This was the default for +-- encode_pretty() prior to version 20141223.14. +-- +-- +-- HANDLING UNICODE LINE AND PARAGRAPH SEPARATORS FOR JAVA +-- +-- If the 'stringsAreUtf8' encode option is set to true, consider Lua strings not as a sequence of bytes, +-- but as a sequence of UTF-8 characters. +-- +-- Currently, the only practical effect of setting this option is that Unicode LINE and PARAGRAPH +-- separators, if found in a string, are encoded with a JSON escape instead of being dumped as is. +-- The JSON is valid either way, but encoding this way, apparently, allows the resulting JSON +-- to also be valid Java. +-- +-- AMBIGUOUS SITUATIONS DURING THE ENCODING +-- +-- During the encode, if a Lua table being encoded contains both string +-- and numeric keys, it fits neither JSON's idea of an object, nor its +-- idea of an array. To get around this, when any string key exists (or +-- when non-positive numeric keys exist), numeric keys are converted to +-- strings. +-- +-- For example, +-- JSON:encode({ "one", "two", "three", SOMESTRING = "some string" })) +-- produces the JSON object +-- {"1":"one","2":"two","3":"three","SOMESTRING":"some string"} +-- +-- To prohibit this conversion and instead make it an error condition, set +-- JSON.noKeyConversion = true +-- +-- +-- ENCODING JSON NULL VALUES +-- +-- Lua tables completely omit keys whose value is nil, so without special handling there's +-- no way to represent JSON object's null value in a Lua table. For example +-- JSON:encode({ username = "admin", password = nil }) +-- +-- produces: +-- +-- {"username":"admin"} +-- +-- In order to actually produce +-- +-- {"username":"admin", "password":null} +-- + +-- one can include a string value for a "null" field in the options table passed to encode().... +-- any Lua table entry with that value becomes null in the JSON output: +-- +-- JSON:encode({ username = "admin", password = "xyzzy" }, -- First arg is the Lua table to encode as JSON. +-- nil, -- Second arg is the 'etc' value, ignored here +-- { null = "xyzzy" }) -- Third arg is th options table +-- +-- produces: +-- +-- {"username":"admin", "password":null} +-- +-- Just be sure to use a string that is otherwise unlikely to appear in your data. +-- The string "\0" (a string with one null byte) may well be appropriate for many applications. +-- +-- The "null" options also applies to Lua tables that become JSON arrays. +-- JSON:encode({ "one", "two", nil, nil }) +-- +-- produces +-- +-- ["one","two"] +-- +-- while +-- +-- NullPlaceholder = "\0" +-- encode_options = { null = NullPlaceholder } +-- JSON:encode({ "one", "two", NullPlaceholder, NullPlaceholder}, nil, encode_options) +-- produces +-- +-- ["one","two",null,null] +-- +-- +-- +-- HANDLING LARGE AND/OR PRECISE NUMBERS +-- +-- +-- Without special handling, numbers in JSON can lose precision in Lua. +-- For example: +-- +-- T = JSON:decode('{ "small":12345, "big":12345678901234567890123456789, "precise":9876.67890123456789012345 }') +-- +-- print("small: ", type(T.small), T.small) +-- print("big: ", type(T.big), T.big) +-- print("precise: ", type(T.precise), T.precise) +-- +-- produces +-- +-- small: number 12345 +-- big: number 1.2345678901235e+28 +-- precise: number 9876.6789012346 +-- +-- Precision is lost with both 'big' and 'precise'. +-- +-- This package offers ways to try to handle this better (for some definitions of "better")... +-- +-- The most precise method is by setting the global: +-- +-- JSON.decodeNumbersAsObjects = true +-- +-- When this is set, numeric JSON data is encoded into Lua in a form that preserves the exact +-- JSON numeric presentation when re-encoded back out to JSON, or accessed in Lua as a string. +-- +-- This is done by encoding the numeric data with a Lua table/metatable that returns +-- the possibly-imprecise numeric form when accessed numerically, but the original precise +-- representation when accessed as a string. +-- +-- Consider the example above, with this option turned on: +-- +-- JSON.decodeNumbersAsObjects = true +-- +-- T = JSON:decode('{ "small":12345, "big":12345678901234567890123456789, "precise":9876.67890123456789012345 }') +-- +-- print("small: ", type(T.small), T.small) +-- print("big: ", type(T.big), T.big) +-- print("precise: ", type(T.precise), T.precise) +-- +-- This now produces: +-- +-- small: table 12345 +-- big: table 12345678901234567890123456789 +-- precise: table 9876.67890123456789012345 +-- +-- However, within Lua you can still use the values (e.g. T.precise in the example above) in numeric +-- contexts. In such cases you'll get the possibly-imprecise numeric version, but in string contexts +-- and when the data finds its way to this package's encode() function, the original full-precision +-- representation is used. +-- +-- You can force access to the string or numeric version via +-- JSON:forceString() +-- JSON:forceNumber() +-- For example, +-- local probably_okay = JSON:forceNumber(T.small) -- 'probably_okay' is a number +-- +-- Code the inspects the JSON-turned-Lua data using type() can run into troubles because what used to +-- be a number can now be a table (e.g. as the small/big/precise example above shows). Update these +-- situations to use JSON:isNumber(item), which returns nil if the item is neither a number nor one +-- of these number objects. If it is either, it returns the number itself. For completeness there's +-- also JSON:isString(item). +-- +-- If you want to try to avoid the hassles of this "number as an object" kludge for all but really +-- big numbers, you can set JSON.decodeNumbersAsObjects and then also set one or both of +-- JSON:decodeIntegerObjectificationLength +-- JSON:decodeDecimalObjectificationLength +-- They refer to the length of the part of the number before and after a decimal point. If they are +-- set and their part is at least that number of digits, objectification occurs. If both are set, +-- objectification occurs when either length is met. +-- +-- ----------------------- +-- +-- Even without using the JSON.decodeNumbersAsObjects option, you can encode numbers in your Lua +-- table that retain high precision upon encoding to JSON, by using the JSON:asNumber() function: +-- +-- T = { +-- imprecise = 123456789123456789.123456789123456789, +-- precise = JSON:asNumber("123456789123456789.123456789123456789") +-- } +-- +-- print(JSON:encode_pretty(T)) +-- +-- This produces: +-- +-- { +-- "precise": 123456789123456789.123456789123456789, +-- "imprecise": 1.2345678912346e+17 +-- } +-- +-- +-- ----------------------- +-- +-- A different way to handle big/precise JSON numbers is to have decode() merely return the exact +-- string representation of the number instead of the number itself. This approach might be useful +-- when the numbers are merely some kind of opaque object identifier and you want to work with them +-- in Lua as strings anyway. +-- +-- This approach is enabled by setting +-- +-- JSON.decodeIntegerStringificationLength = 10 +-- +-- The value is the number of digits (of the integer part of the number) at which to stringify numbers. +-- NOTE: this setting is ignored if JSON.decodeNumbersAsObjects is true, as that takes precedence. +-- +-- Consider our previous example with this option set to 10: +-- +-- JSON.decodeIntegerStringificationLength = 10 +-- +-- T = JSON:decode('{ "small":12345, "big":12345678901234567890123456789, "precise":9876.67890123456789012345 }') +-- +-- print("small: ", type(T.small), T.small) +-- print("big: ", type(T.big), T.big) +-- print("precise: ", type(T.precise), T.precise) +-- +-- This produces: +-- +-- small: number 12345 +-- big: string 12345678901234567890123456789 +-- precise: number 9876.6789012346 +-- +-- The long integer of the 'big' field is at least JSON.decodeIntegerStringificationLength digits +-- in length, so it's converted not to a Lua integer but to a Lua string. Using a value of 0 or 1 ensures +-- that all JSON numeric data becomes strings in Lua. +-- +-- Note that unlike +-- JSON.decodeNumbersAsObjects = true +-- this stringification is simple and unintelligent: the JSON number simply becomes a Lua string, and that's the end of it. +-- If the string is then converted back to JSON, it's still a string. After running the code above, adding +-- print(JSON:encode(T)) +-- produces +-- {"big":"12345678901234567890123456789","precise":9876.6789012346,"small":12345} +-- which is unlikely to be desired. +-- +-- There's a comparable option for the length of the decimal part of a number: +-- +-- JSON.decodeDecimalStringificationLength +-- +-- This can be used alone or in conjunction with +-- +-- JSON.decodeIntegerStringificationLength +-- +-- to trip stringification on precise numbers with at least JSON.decodeIntegerStringificationLength digits after +-- the decimal point. (Both are ignored if JSON.decodeNumbersAsObjects is true.) +-- +-- This example: +-- +-- JSON.decodeIntegerStringificationLength = 10 +-- JSON.decodeDecimalStringificationLength = 5 +-- +-- T = JSON:decode('{ "small":12345, "big":12345678901234567890123456789, "precise":9876.67890123456789012345 }') +-- +-- print("small: ", type(T.small), T.small) +-- print("big: ", type(T.big), T.big) +-- print("precise: ", type(T.precise), T.precise) +-- +-- produces: +-- +-- small: number 12345 +-- big: string 12345678901234567890123456789 +-- precise: string 9876.67890123456789012345 +-- +-- +-- HANDLING UNSUPPORTED VALUE TYPES +-- +-- Among the encoding errors that might be raised is an attempt to convert a table value that has a type +-- that this package hasn't accounted for: a function, userdata, or a thread. You can handle these types as table +-- values (but not as table keys) if you supply a JSON:unsupportedTypeEncoder() method along the lines of the +-- following example: +-- +-- function JSON:unsupportedTypeEncoder(value_of_unsupported_type) +-- if type(value_of_unsupported_type) == 'function' then +-- return "a function value" +-- else +-- return nil +-- end +-- end +-- +-- Your unsupportedTypeEncoder() method is actually called with a bunch of arguments: +-- +-- self:unsupportedTypeEncoder(value, parents, etc, options, indent, for_key) +-- +-- The 'value' is the function, thread, or userdata to be converted to JSON. +-- +-- The 'etc' and 'options' arguments are those passed to the original encode(). The other arguments are +-- probably of little interest; see the source code. (Note that 'for_key' is never true, as this function +-- is invoked only on table values; table keys of these types still trigger the onEncodeError method.) +-- +-- If your unsupportedTypeEncoder() method returns a string, it's inserted into the JSON as is. +-- If it returns nil plus an error message, that error message is passed through to an onEncodeError invocation. +-- If it returns only nil, processing falls through to a default onEncodeError invocation. +-- +-- If you want to handle everything in a simple way: +-- +-- function JSON:unsupportedTypeEncoder(value) +-- return tostring(value) +-- end +-- +-- +-- SUMMARY OF METHODS YOU CAN OVERRIDE IN YOUR LOCAL LUA JSON OBJECT +-- +-- assert +-- onDecodeError +-- onDecodeOfNilError +-- onDecodeOfHTMLError +-- onTrailingGarbage +-- onEncodeError +-- unsupportedTypeEncoder +-- +-- If you want to create a separate Lua JSON object with its own error handlers, +-- you can reload JSON.lua or use the :new() method. +-- +--------------------------------------------------------------------------- + +local default_pretty_indent = " " +local default_pretty_options = { pretty = true, indent = default_pretty_indent, align_keys = false, array_newline = false } + +local isArray = { __tostring = function() return "JSON array" end } isArray.__index = isArray +local isObject = { __tostring = function() return "JSON object" end } isObject.__index = isObject + +function OBJDEF:newArray(tbl) + return setmetatable(tbl or {}, isArray) +end + +function OBJDEF:newObject(tbl) + return setmetatable(tbl or {}, isObject) +end + + + + +local function getnum(op) + return type(op) == 'number' and op or op.N +end + +local isNumber = { + __tostring = function(T) return T.S end, + __unm = function(op) return getnum(op) end, + + __concat = function(op1, op2) return tostring(op1) .. tostring(op2) end, + __add = function(op1, op2) return getnum(op1) + getnum(op2) end, + __sub = function(op1, op2) return getnum(op1) - getnum(op2) end, + __mul = function(op1, op2) return getnum(op1) * getnum(op2) end, + __div = function(op1, op2) return getnum(op1) / getnum(op2) end, + __mod = function(op1, op2) return getnum(op1) % getnum(op2) end, + __pow = function(op1, op2) return getnum(op1) ^ getnum(op2) end, + __lt = function(op1, op2) return getnum(op1) < getnum(op2) end, + __eq = function(op1, op2) return getnum(op1) == getnum(op2) end, + __le = function(op1, op2) return getnum(op1) <= getnum(op2) end, +} +isNumber.__index = isNumber + +function OBJDEF:asNumber(item) + + if getmetatable(item) == isNumber then + -- it's already a JSON number object. + return item + elseif type(item) == 'table' and type(item.S) == 'string' and type(item.N) == 'number' then + -- it's a number-object table that lost its metatable, so give it one + return setmetatable(item, isNumber) + else + -- the normal situation... given a number or a string representation of a number.... + local holder = { + S = tostring(item), -- S is the representation of the number as a string, which remains precise + N = tonumber(item), -- N is the number as a Lua number. + } + return setmetatable(holder, isNumber) + end +end + +-- +-- Given an item that might be a normal string or number, or might be an 'isNumber' object defined above, +-- return the string version. This shouldn't be needed often because the 'isNumber' object should autoconvert +-- to a string in most cases, but it's here to allow it to be forced when needed. +-- +function OBJDEF:forceString(item) + if type(item) == 'table' and type(item.S) == 'string' then + return item.S + else + return tostring(item) + end +end + +-- +-- Given an item that might be a normal string or number, or might be an 'isNumber' object defined above, +-- return the numeric version. +-- +function OBJDEF:forceNumber(item) + if type(item) == 'table' and type(item.N) == 'number' then + return item.N + else + return tonumber(item) + end +end + +-- +-- If the given item is a number, return it. Otherwise, return nil. +-- This, this can be used both in a conditional and to access the number when you're not sure its form. +-- +function OBJDEF:isNumber(item) + if type(item) == 'number' then + return item + elseif type(item) == 'table' and type(item.N) == 'number' then + return item.N + else + return nil + end +end + +function OBJDEF:isString(item) + if type(item) == 'string' then + return item + elseif type(item) == 'table' and type(item.S) == 'string' then + return item.S + else + return nil + end +end + + + + +-- +-- Some utf8 routines to deal with the fact that Lua handles only bytes +-- +local function top_three_bits(val) + return math.floor(val / 0x20) +end + +local function top_four_bits(val) + return math.floor(val / 0x10) +end + +local function unicode_character_bytecount_based_on_first_byte(first_byte) + local W = string.byte(first_byte) + if W < 0x80 then + return 1 + elseif (W == 0xC0) or (W == 0xC1) or (W >= 0x80 and W <= 0xBF) or (W >= 0xF5) then + -- this is an error -- W can't be the start of a utf8 character + return 0 + elseif top_three_bits(W) == 0x06 then + return 2 + elseif top_four_bits(W) == 0x0E then + return 3 + else + return 4 + end +end + + + +local function unicode_codepoint_as_utf8(codepoint) + -- + -- codepoint is a number + -- + if codepoint <= 127 then + return string.char(codepoint) + + elseif codepoint <= 2047 then + -- + -- 110yyyxx 10xxxxxx <-- useful notation from http://en.wikipedia.org/wiki/Utf8 + -- + local highpart = math.floor(codepoint / 0x40) + local lowpart = codepoint - (0x40 * highpart) + return string.char(0xC0 + highpart, + 0x80 + lowpart) + + elseif codepoint <= 65535 then + -- + -- 1110yyyy 10yyyyxx 10xxxxxx + -- + local highpart = math.floor(codepoint / 0x1000) + local remainder = codepoint - 0x1000 * highpart + local midpart = math.floor(remainder / 0x40) + local lowpart = remainder - 0x40 * midpart + + highpart = 0xE0 + highpart + midpart = 0x80 + midpart + lowpart = 0x80 + lowpart + + -- + -- Check for an invalid character (thanks Andy R. at Adobe). + -- See table 3.7, page 93, in http://www.unicode.org/versions/Unicode5.2.0/ch03.pdf#G28070 + -- + if ( highpart == 0xE0 and midpart < 0xA0 ) or + ( highpart == 0xED and midpart > 0x9F ) or + ( highpart == 0xF0 and midpart < 0x90 ) or + ( highpart == 0xF4 and midpart > 0x8F ) + then + return "?" + else + return string.char(highpart, + midpart, + lowpart) + end + + else + -- + -- 11110zzz 10zzyyyy 10yyyyxx 10xxxxxx + -- + local highpart = math.floor(codepoint / 0x40000) + local remainder = codepoint - 0x40000 * highpart + local midA = math.floor(remainder / 0x1000) + remainder = remainder - 0x1000 * midA + local midB = math.floor(remainder / 0x40) + local lowpart = remainder - 0x40 * midB + + return string.char(0xF0 + highpart, + 0x80 + midA, + 0x80 + midB, + 0x80 + lowpart) + end +end + +function OBJDEF:onDecodeError(message, text, location, etc) + if text then + if location then + message = string.format("%s at byte %d of: %s", message, location, text) + else + message = string.format("%s: %s", message, text) + end + end + + if etc ~= nil then + message = message .. " (" .. OBJDEF:encode(etc) .. ")" + end + + if self.assert then + self.assert(false, message) + else + assert(false, message) + end +end + +function OBJDEF:onTrailingGarbage(json_text, location, parsed_value, etc) + return self:onDecodeError("trailing garbage", json_text, location, etc) +end + +OBJDEF.onDecodeOfNilError = OBJDEF.onDecodeError +OBJDEF.onDecodeOfHTMLError = OBJDEF.onDecodeError + +function OBJDEF:onEncodeError(message, etc) + if etc ~= nil then + message = message .. " (" .. OBJDEF:encode(etc) .. ")" + end + + if self.assert then + self.assert(false, message) + else + assert(false, message) + end +end + +local function grok_number(self, text, start, options) + -- + -- Grab the integer part + -- + local integer_part = text:match('^-?[1-9]%d*', start) + or text:match("^-?0", start) + + if not integer_part then + self:onDecodeError("expected number", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + local i = start + integer_part:len() + + -- + -- Grab an optional decimal part + -- + local decimal_part = text:match('^%.%d+', i) or "" + + i = i + decimal_part:len() + + -- + -- Grab an optional exponential part + -- + local exponent_part = text:match('^[eE][-+]?%d+', i) or "" + + i = i + exponent_part:len() + + local full_number_text = integer_part .. decimal_part .. exponent_part + + if options.decodeNumbersAsObjects then + + local objectify = false + + if not options.decodeIntegerObjectificationLength and not options.decodeDecimalObjectificationLength then + -- no options, so objectify + objectify = true + + elseif (options.decodeIntegerObjectificationLength + and + (integer_part:len() >= options.decodeIntegerObjectificationLength or exponent_part:len() > 0)) + + or + (options.decodeDecimalObjectificationLength + and + (decimal_part:len() >= options.decodeDecimalObjectificationLength or exponent_part:len() > 0)) + then + -- have options and they are triggered, so objectify + objectify = true + end + + if objectify then + return OBJDEF:asNumber(full_number_text), i + end + -- else, fall through to try to return as a straight-up number + + else + + -- Not always decoding numbers as objects, so perhaps encode as strings? + + -- + -- If we're told to stringify only under certain conditions, so do. + -- We punt a bit when there's an exponent by just stringifying no matter what. + -- I suppose we should really look to see whether the exponent is actually big enough one + -- way or the other to trip stringification, but I'll be lazy about it until someone asks. + -- + if (options.decodeIntegerStringificationLength + and + (integer_part:len() >= options.decodeIntegerStringificationLength or exponent_part:len() > 0)) + + or + + (options.decodeDecimalStringificationLength + and + (decimal_part:len() >= options.decodeDecimalStringificationLength or exponent_part:len() > 0)) + then + return full_number_text, i -- this returns the exact string representation seen in the original JSON + end + + end + + + local as_number = tonumber(full_number_text) + + if not as_number then + self:onDecodeError("bad number", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + return as_number, i +end + + +local backslash_escape_conversion = { + ['"'] = '"', + ['/'] = "/", + ['\\'] = "\\", + ['b'] = "\b", + ['f'] = "\f", + ['n'] = "\n", + ['r'] = "\r", + ['t'] = "\t", +} + +local function grok_string(self, text, start, options) + + if text:sub(start,start) ~= '"' then + self:onDecodeError("expected string's opening quote", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + local i = start + 1 -- +1 to bypass the initial quote + local text_len = text:len() + local VALUE = "" + while i <= text_len do + local c = text:sub(i,i) + if c == '"' then + return VALUE, i + 1 + end + if c ~= '\\' then + + -- should grab the next bytes as per the number of bytes for this utf8 character + local byte_count = unicode_character_bytecount_based_on_first_byte(c) + + local next_character + if byte_count == 0 then + self:onDecodeError("non-utf8 sequence", text, i, options.etc) + elseif byte_count == 1 then + if options.strictParsing and string.byte(c) < 0x20 then + self:onDecodeError("Unescaped control character", text, i+1, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + next_character = c + elseif byte_count == 2 then + next_character = text:match('^(.[\128-\191])', i) + elseif byte_count == 3 then + next_character = text:match('^(.[\128-\191][\128-\191])', i) + elseif byte_count == 4 then + next_character = text:match('^(.[\128-\191][\128-\191][\128-\191])', i) + end + + if not next_character then + self:onDecodeError("incomplete utf8 sequence", text, i, options.etc) + return nil, i -- in case the error method doesn't abort, return something sensible + end + + + VALUE = VALUE .. next_character + i = i + byte_count + + else + -- + -- We have a backslash escape + -- + i = i + 1 + + local next_byte = text:match('^(.)', i) + + if next_byte == nil then + -- string ended after the \ + self:onDecodeError("unfinished \\ escape", text, i, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + if backslash_escape_conversion[next_byte] then + VALUE = VALUE .. backslash_escape_conversion[next_byte] + i = i + 1 + else + -- + -- The only other valid use of \ that remains is in the form of \u#### + -- + + local hex = text:match('^u([0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF])', i) + if hex then + i = i + 5 -- bypass what we just read + + -- We have a Unicode codepoint. It could be standalone, or if in the proper range and + -- followed by another in a specific range, it'll be a two-code surrogate pair. + local codepoint = tonumber(hex, 16) + if codepoint >= 0xD800 and codepoint <= 0xDBFF then + -- it's a hi surrogate... see whether we have a following low + local lo_surrogate = text:match('^\\u([dD][cdefCDEF][0123456789aAbBcCdDeEfF][0123456789aAbBcCdDeEfF])', i) + if lo_surrogate then + i = i + 6 -- bypass the low surrogate we just read + codepoint = 0x2400 + (codepoint - 0xD800) * 0x400 + tonumber(lo_surrogate, 16) + else + -- not a proper low, so we'll just leave the first codepoint as is and spit it out. + end + end + VALUE = VALUE .. unicode_codepoint_as_utf8(codepoint) + + elseif options.strictParsing then + --local next_byte = text:match('^\\(.)', i) printf("NEXT[%s]", next_byte); + self:onDecodeError("illegal use of backslash escape", text, i, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + else + local byte_count = unicode_character_bytecount_based_on_first_byte(next_byte) + if byte_count == 0 then + self:onDecodeError("non-utf8 sequence after backslash escape", text, i, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + local next_character + if byte_count == 1 then + next_character = next_byte + elseif byte_count == 2 then + next_character = text:match('^(.[\128-\191])', i) + elseif byte_count == 3 then + next_character = text:match('^(.[\128-\191][\128-\191])', i) + elseif byte_count == 3 then + next_character = text:match('^(.[\128-\191][\128-\191][\128-\191])', i) + end + + if next_character == nil then + -- incomplete utf8 character after escape + self:onDecodeError("incomplete utf8 sequence after backslash escape", text, i, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + VALUE = VALUE .. next_character + i = i + byte_count + end + end + end + end + + self:onDecodeError("unclosed string", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible +end + +local function skip_whitespace(text, start) + + local _, match_end = text:find("^[ \n\r\t]+", start) -- [ https://datatracker.ietf.org/doc/html/rfc7158#section-2 ] + if match_end then + return match_end + 1 + else + return start + end +end + +local grok_one -- assigned later + +local function grok_object(self, text, start, options) + + if text:sub(start,start) ~= '{' then + self:onDecodeError("expected '{'", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + local i = skip_whitespace(text, start + 1) -- +1 to skip the '{' + + local VALUE = self.strictTypes and self:newObject { } or { } + + if text:sub(i,i) == '}' then + return VALUE, i + 1 + end + local text_len = text:len() + while i <= text_len do + local key, new_i = grok_string(self, text, i, options) + + i = skip_whitespace(text, new_i) + + if text:sub(i, i) ~= ':' then + self:onDecodeError("expected colon", text, i, options.etc) + return nil, i -- in case the error method doesn't abort, return something sensible + end + + i = skip_whitespace(text, i + 1) + + local new_val, new_i = grok_one(self, text, i, options) + + VALUE[key] = new_val + + -- + -- Expect now either '}' to end things, or a ',' to allow us to continue. + -- + i = skip_whitespace(text, new_i) + + local c = text:sub(i,i) + + if c == '}' then + return VALUE, i + 1 + end + + if text:sub(i, i) ~= ',' then + self:onDecodeError("expected comma or '}'", text, i, options.etc) + return nil, i -- in case the error method doesn't abort, return something sensible + end + + i = skip_whitespace(text, i + 1) + end + + self:onDecodeError("unclosed '{'", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible +end + +local function grok_array(self, text, start, options) + if text:sub(start,start) ~= '[' then + self:onDecodeError("expected '['", text, start, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + local i = skip_whitespace(text, start + 1) -- +1 to skip the '[' + local VALUE = self.strictTypes and self:newArray { } or { } + if text:sub(i,i) == ']' then + return VALUE, i + 1 + end + + local VALUE_INDEX = 1 + + local text_len = text:len() + while i <= text_len do + local val, new_i = grok_one(self, text, i, options) + + -- can't table.insert(VALUE, val) here because it's a no-op if val is nil + VALUE[VALUE_INDEX] = val + VALUE_INDEX = VALUE_INDEX + 1 + + i = skip_whitespace(text, new_i) + + -- + -- Expect now either ']' to end things, or a ',' to allow us to continue. + -- + local c = text:sub(i,i) + if c == ']' then + return VALUE, i + 1 + end + if text:sub(i, i) ~= ',' then + self:onDecodeError("expected comma or ']'", text, i, options.etc) + return nil, i -- in case the error method doesn't abort, return something sensible + end + i = skip_whitespace(text, i + 1) + end + self:onDecodeError("unclosed '['", text, start, options.etc) + return nil, i -- in case the error method doesn't abort, return something sensible +end + + +grok_one = function(self, text, start, options) + -- Skip any whitespace + start = skip_whitespace(text, start) + + if start > text:len() then + self:onDecodeError("unexpected end of string", text, nil, options.etc) + return nil, start -- in case the error method doesn't abort, return something sensible + end + + if text:find('^"', start) then + return grok_string(self, text, start, options) + + elseif text:find('^[-0123456789 ]', start) then + return grok_number(self, text, start, options) + + elseif text:find('^%{', start) then + return grok_object(self, text, start, options) + + elseif text:find('^%[', start) then + return grok_array(self, text, start, options) + + elseif text:find('^true', start) then + return true, start + 4 + + elseif text:find('^false', start) then + return false, start + 5 + + elseif text:find('^null', start) then + return options.null, start + 4 + + else + self:onDecodeError("can't parse JSON", text, start, options.etc) + return nil, 1 -- in case the error method doesn't abort, return something sensible + end +end + +function OBJDEF:decode(text, etc, options) + -- + -- If the user didn't pass in a table of decode options, make an empty one. + -- + if type(options) ~= 'table' then + options = {} + end + + -- + -- If they passed in an 'etc' argument, stuff it into the options. + -- (If not, any 'etc' field in the options they passed in remains to be used) + -- + if etc ~= nil then + options.etc = etc + end + + + -- + -- apply global options + -- + if options.decodeNumbersAsObjects == nil then + options.decodeNumbersAsObjects = self.decodeNumbersAsObjects + end + if options.decodeIntegerObjectificationLength == nil then + options.decodeIntegerObjectificationLength = self.decodeIntegerObjectificationLength + end + if options.decodeDecimalObjectificationLength == nil then + options.decodeDecimalObjectificationLength = self.decodeDecimalObjectificationLength + end + if options.decodeIntegerStringificationLength == nil then + options.decodeIntegerStringificationLength = self.decodeIntegerStringificationLength + end + if options.decodeDecimalStringificationLength == nil then + options.decodeDecimalStringificationLength = self.decodeDecimalStringificationLength + end + if options.strictParsing == nil then + options.strictParsing = self.strictParsing + end + + + if type(self) ~= 'table' or self.__index ~= OBJDEF then + local error_message = "JSON:decode must be called in method format" + OBJDEF:onDecodeError(error_message, nil, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + end + + if text == nil then + local error_message = "nil passed to JSON:decode()" + self:onDecodeOfNilError(error_message, nil, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + + elseif type(text) ~= 'string' then + local error_message = "expected string argument to JSON:decode()" + self:onDecodeError(string.format("%s, got %s", error_message, type(text)), nil, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + end + + -- If passed an empty string.... + if text:match('^%s*$') then + if options.strictParsing then + local error_message = "empty string passed to JSON:decode()" + self:onDecodeOfNilError(error_message, nil, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + else + -- we'll consider it nothing, but not an error + return nil + end + end + + if text:match('^%s*<') then + -- Can't be JSON... we'll assume it's HTML + local error_message = "HTML passed to JSON:decode()" + self:onDecodeOfHTMLError(error_message, text, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + end + + -- + -- Ensure that it's not UTF-32 or UTF-16. + -- Those are perfectly valid encodings for JSON (as per RFC 4627 section 3), + -- but this package can't handle them. + -- + if text:sub(1,1):byte() == 0 or (text:len() >= 2 and text:sub(2,2):byte() == 0) then + local error_message = "JSON package groks only UTF-8, sorry" + self:onDecodeError(error_message, text, nil, options.etc) + return nil, error_message -- in case the error method doesn't abort, return something sensible + end + + + -- + -- Finally, go parse it + -- + local success, value, next_i = pcall(grok_one, self, text, 1, options) + + if success then + + local error_message = nil + if next_i ~= #text + 1 then + -- something's left over after we parsed the first thing.... whitespace is allowed. + next_i = skip_whitespace(text, next_i) + + -- if we have something left over now, it's trailing garbage + if next_i ~= #text + 1 then + value, error_message = self:onTrailingGarbage(text, next_i, value, options.etc) + end + end + return value, error_message + + else + + -- If JSON:onDecodeError() didn't abort out of the pcall, we'll have received + -- the error message here as "value", so pass it along as an assert. + local error_message = value + if self.assert then + self.assert(false, error_message) + else + assert(false, error_message) + end + -- ...and if we're still here (because the assert didn't throw an error), + -- return a nil and throw the error message on as a second arg + return nil, error_message + + end +end + +local function backslash_replacement_function(c) + if c == "\n" then return "\\n" + elseif c == "\r" then return "\\r" + elseif c == "\t" then return "\\t" + elseif c == "\b" then return "\\b" + elseif c == "\f" then return "\\f" + elseif c == '"' then return '\\"' + elseif c == '\\' then return '\\\\' + elseif c == '/' then return '/' + else + return string.format("\\u%04x", c:byte()) + end +end + +local chars_to_be_escaped_in_JSON_string + = '[' + .. '"' -- class sub-pattern to match a double quote + .. '%\\' -- class sub-pattern to match a backslash + .. '/' -- class sub-pattern to match a forwardslash + .. '%z' -- class sub-pattern to match a null + .. '\001' .. '-' .. '\031' -- class sub-pattern to match control characters + .. ']' + + +local LINE_SEPARATOR_as_utf8 = unicode_codepoint_as_utf8(0x2028) +local PARAGRAPH_SEPARATOR_as_utf8 = unicode_codepoint_as_utf8(0x2029) +local function json_string_literal(value, options) + local newval = value:gsub(chars_to_be_escaped_in_JSON_string, backslash_replacement_function) + if options.stringsAreUtf8 then + -- + -- This feels really ugly to just look into a string for the sequence of bytes that we know to be a particular utf8 character, + -- but utf8 was designed purposefully to make this kind of thing possible. Still, feels dirty. + -- I'd rather decode the byte stream into a character stream, but it's not technically needed so + -- not technically worth it. + -- + newval = newval:gsub(LINE_SEPARATOR_as_utf8, '\\u2028'):gsub(PARAGRAPH_SEPARATOR_as_utf8,'\\u2029') + end + return '"' .. newval .. '"' +end + +local function object_or_array(self, T, etc) + -- + -- We need to inspect all the keys... if there are any strings, we'll convert to a JSON + -- object. If there are only numbers, it's a JSON array. + -- + -- If we'll be converting to a JSON object, we'll want to sort the keys so that the + -- end result is deterministic. + -- + local string_keys = { } + local number_keys = { } + local number_keys_must_be_strings = false + local maximum_number_key + + for key in pairs(T) do + if type(key) == 'string' then + table.insert(string_keys, key) + elseif type(key) == 'number' then + table.insert(number_keys, key) + if key <= 0 or key >= math.huge then + number_keys_must_be_strings = true + elseif not maximum_number_key or key > maximum_number_key then + maximum_number_key = key + end + elseif type(key) == 'boolean' then + table.insert(string_keys, tostring(key)) + else + self:onEncodeError("can't encode table with a key of type " .. type(key), etc) + end + end + + if #string_keys == 0 and not number_keys_must_be_strings then + -- + -- An empty table, or a numeric-only array + -- + if #number_keys > 0 then + return nil, maximum_number_key -- an array + elseif tostring(T) == "JSON array" then + return nil + elseif tostring(T) == "JSON object" then + return { } + else + -- have to guess, so we'll pick array, since empty arrays are likely more common than empty objects + return nil + end + end + + table.sort(string_keys) + + local map + if #number_keys > 0 then + -- + -- If we're here then we have either mixed string/number keys, or numbers inappropriate for a JSON array + -- It's not ideal, but we'll turn the numbers into strings so that we can at least create a JSON object. + -- + + if self.noKeyConversion then + self:onEncodeError("a table with both numeric and string keys could be an object or array; aborting", etc) + end + + -- + -- Have to make a shallow copy of the source table so we can remap the numeric keys to be strings + -- + map = { } + for key, val in pairs(T) do + map[key] = val + end + + table.sort(number_keys) + + -- + -- Throw numeric keys in there as strings + -- + for _, number_key in ipairs(number_keys) do + local string_key = tostring(number_key) + if map[string_key] == nil then + table.insert(string_keys , string_key) + map[string_key] = T[number_key] + else + self:onEncodeError("conflict converting table with mixed-type keys into a JSON object: key " .. number_key .. " exists both as a string and a number.", etc) + end + end + end + + return string_keys, nil, map +end + +-- +-- Encode +-- +-- 'options' is nil, or a table with possible keys: +-- +-- pretty -- If true, return a pretty-printed version. +-- +-- indent -- A string (usually of spaces) used to indent each nested level. +-- +-- align_keys -- If true, align all the keys when formatting a table. The result is uglier than one might at first imagine. +-- Results are undefined if 'align_keys' is true but 'pretty' is not. +-- +-- array_newline -- If true, array elements are formatted each to their own line. The default is to all fall inline. +-- Results are undefined if 'array_newline' is true but 'pretty' is not. +-- +-- null -- If this exists with a string value, table elements with this value are output as JSON null. +-- +-- stringsAreUtf8 -- If true, consider Lua strings not as a sequence of bytes, but as a sequence of UTF-8 characters. +-- (Currently, the only practical effect of setting this option is that Unicode LINE and PARAGRAPH +-- separators, if found in a string, are encoded with a JSON escape instead of as raw UTF-8. +-- The JSON is valid either way, but encoding this way, apparently, allows the resulting JSON +-- to also be valid Java.) +-- +-- +local function encode_value(self, value, parents, etc, options, indent, for_key) + + -- + -- keys in a JSON object can never be null, so we don't even consider options.null when converting a key value + -- + if value == nil or (not for_key and options and options.null and value == options.null) then + return 'null' + + elseif type(value) == 'string' then + return json_string_literal(value, options) + + elseif type(value) == 'number' then + if value ~= value then + -- + -- NaN (Not a Number). + -- JSON has no NaN, so we have to fudge the best we can. This should really be a package option. + -- + return "null" + elseif value >= math.huge then + -- + -- Positive infinity. JSON has no INF, so we have to fudge the best we can. This should + -- really be a package option. Note: at least with some implementations, positive infinity + -- is both ">= math.huge" and "<= -math.huge", which makes no sense but that's how it is. + -- Negative infinity is properly "<= -math.huge". So, we must be sure to check the ">=" + -- case first. + -- + return "1e+9999" + elseif value <= -math.huge then + -- + -- Negative infinity. + -- JSON has no INF, so we have to fudge the best we can. This should really be a package option. + -- + return "-1e+9999" + else + return tostring(value) + end + + elseif type(value) == 'boolean' then + return tostring(value) + + elseif type(value) ~= 'table' then + + if self.unsupportedTypeEncoder then + local user_value, user_error = self:unsupportedTypeEncoder(value, parents, etc, options, indent, for_key) + -- If the user's handler returns a string, use that. If it returns nil plus an error message, bail with that. + -- If only nil returned, fall through to the default error handler. + if type(user_value) == 'string' then + return user_value + elseif user_value ~= nil then + self:onEncodeError("unsupportedTypeEncoder method returned a " .. type(user_value), etc) + elseif user_error then + self:onEncodeError(tostring(user_error), etc) + end + end + + self:onEncodeError("can't convert " .. type(value) .. " to JSON", etc) + + elseif getmetatable(value) == isNumber then + return tostring(value) + else + -- + -- A table to be converted to either a JSON object or array. + -- + local T = value + + if type(options) ~= 'table' then + options = {} + end + if type(indent) ~= 'string' then + indent = "" + end + + if parents[T] then + self:onEncodeError("table " .. tostring(T) .. " is a child of itself", etc) + else + parents[T] = true + end + + local result_value + + local object_keys, maximum_number_key, map = object_or_array(self, T, etc) + if maximum_number_key then + -- + -- An array... + -- + local key_indent + if options.array_newline then + key_indent = indent .. tostring(options.indent or "") + else + key_indent = indent + end + + local ITEMS = { } + for i = 1, maximum_number_key do + table.insert(ITEMS, encode_value(self, T[i], parents, etc, options, key_indent)) + end + + if options.array_newline then + result_value = "[\n" .. key_indent .. table.concat(ITEMS, ",\n" .. key_indent) .. "\n" .. indent .. "]" + elseif options.pretty then + result_value = "[ " .. table.concat(ITEMS, ", ") .. " ]" + else + result_value = "[" .. table.concat(ITEMS, ",") .. "]" + end + + elseif object_keys then + -- + -- An object + -- + local TT = map or T + + if options.pretty then + + local KEYS = { } + local max_key_length = 0 + for _, key in ipairs(object_keys) do + local encoded = encode_value(self, tostring(key), parents, etc, options, indent, true) + if options.align_keys then + max_key_length = math.max(max_key_length, #encoded) + end + table.insert(KEYS, encoded) + end + local key_indent = indent .. tostring(options.indent or "") + local subtable_indent = key_indent .. string.rep(" ", max_key_length) .. (options.align_keys and " " or "") + local FORMAT = "%s%" .. string.format("%d", max_key_length) .. "s: %s" + + local COMBINED_PARTS = { } + for i, key in ipairs(object_keys) do + local encoded_val = encode_value(self, TT[key], parents, etc, options, subtable_indent) + table.insert(COMBINED_PARTS, string.format(FORMAT, key_indent, KEYS[i], encoded_val)) + end + result_value = "{\n" .. table.concat(COMBINED_PARTS, ",\n") .. "\n" .. indent .. "}" + + else + + local PARTS = { } + for _, key in ipairs(object_keys) do + local encoded_val = encode_value(self, TT[key], parents, etc, options, indent) + local encoded_key = encode_value(self, tostring(key), parents, etc, options, indent, true) + table.insert(PARTS, string.format("%s:%s", encoded_key, encoded_val)) + end + result_value = "{" .. table.concat(PARTS, ",") .. "}" + + end + else + -- + -- An empty array/object... we'll treat it as an array, though it should really be an option + -- + result_value = "[]" + end + + parents[T] = false + return result_value + end +end + +local function top_level_encode(self, value, etc, options) + local val = encode_value(self, value, {}, etc, options) + if val == nil then + --PRIVATE("may need to revert to the previous public verison if I can't figure out what the guy wanted") + return val + else + return val + end +end + +function OBJDEF:encode(value, etc, options) + if type(self) ~= 'table' or self.__index ~= OBJDEF then + OBJDEF:onEncodeError("JSON:encode must be called in method format", etc) + end + + -- + -- If the user didn't pass in a table of decode options, make an empty one. + -- + if type(options) ~= 'table' then + options = {} + end + + return top_level_encode(self, value, etc, options) +end + +function OBJDEF:encode_pretty(value, etc, options) + if type(self) ~= 'table' or self.__index ~= OBJDEF then + OBJDEF:onEncodeError("JSON:encode_pretty must be called in method format", etc) + end + + -- + -- If the user didn't pass in a table of decode options, use the default pretty ones + -- + if type(options) ~= 'table' then + options = default_pretty_options + end + + return top_level_encode(self, value, etc, options) +end + +function OBJDEF.__tostring() + return "JSON encode/decode package" +end + +OBJDEF.__index = OBJDEF + +function OBJDEF:new(args) + local new = { } + + if args then + for key, val in pairs(args) do + new[key] = val + end + end + + return setmetatable(new, OBJDEF) +end + +return OBJDEF:new() + +-- +-- Version history: +-- +-- 20211016.28 Had forgotten to document the strictParsing option. +-- +-- 20211015.27 Better handle some edge-case errors [ thank you http://seriot.ch/projects/parsing_json.html ; all tests are now successful ] +-- +-- Added some semblance of proper UTF8 parsing, and now aborts with an error on ilformatted UTF8. +-- +-- Added the strictParsing option: +-- Aborts with an error on unknown backslash-escape in strings +-- Aborts on naked control characters in strings +-- Aborts when decode is passed a whitespace-only string +-- +-- For completeness, when encoding a Lua string into a JSON string, escape a forward slash. +-- +-- String decoding should be a bit more efficient now. +-- +-- 20170927.26 Use option.null in decoding as well. Thanks to Max Sindwani for the bump, and sorry to Oliver Hitz +-- whose first mention of it four years ago was completely missed by me. +-- +-- 20170823.25 Added support for JSON:unsupportedTypeEncoder(). +-- Thanks to Chronos Phaenon Eosphoros (https://github.com/cpeosphoros) for the idea. +-- +-- 20170819.24 Added support for boolean keys in tables. +-- +-- 20170416.23 Added the "array_newline" formatting option suggested by yurenchen (http://www.yurenchen.com/) +-- +-- 20161128.22 Added: +-- JSON:isString() +-- JSON:isNumber() +-- JSON:decodeIntegerObjectificationLength +-- JSON:decodeDecimalObjectificationLength +-- +-- 20161109.21 Oops, had a small boo-boo in the previous update. +-- +-- 20161103.20 Used to silently ignore trailing garbage when decoding. Now fails via JSON:onTrailingGarbage() +-- http://seriot.ch/parsing_json.php +-- +-- Built-in error message about "expected comma or ']'" had mistakenly referred to '[' +-- +-- Updated the built-in error reporting to refer to bytes rather than characters. +-- +-- The decode() method no longer assumes that error handlers abort. +-- +-- Made the VERSION string a string instead of a number +-- + +-- 20160916.19 Fixed the isNumber.__index assignment (thanks to Jack Taylor) +-- +-- 20160730.18 Added JSON:forceString() and JSON:forceNumber() +-- +-- 20160728.17 Added concatenation to the metatable for JSON:asNumber() +-- +-- 20160709.16 Could crash if not passed an options table (thanks jarno heikkinen ). +-- +-- Made JSON:asNumber() a bit more resilient to being passed the results of itself. +-- +-- 20160526.15 Added the ability to easily encode null values in JSON, via the new "null" encoding option. +-- (Thanks to Adam B for bringing up the issue.) +-- +-- Added some support for very large numbers and precise floats via +-- JSON.decodeNumbersAsObjects +-- JSON.decodeIntegerStringificationLength +-- JSON.decodeDecimalStringificationLength +-- +-- Added the "stringsAreUtf8" encoding option. (Hat tip to http://lua-users.org/wiki/JsonModules ) +-- +-- 20141223.14 The encode_pretty() routine produced fine results for small datasets, but isn't really +-- appropriate for anything large, so with help from Alex Aulbach I've made the encode routines +-- more flexible, and changed the default encode_pretty() to be more generally useful. +-- +-- Added a third 'options' argument to the encode() and encode_pretty() routines, to control +-- how the encoding takes place. +-- +-- Updated docs to add assert() call to the loadfile() line, just as good practice so that +-- if there is a problem loading JSON.lua, the appropriate error message will percolate up. +-- +-- 20140920.13 Put back (in a way that doesn't cause warnings about unused variables) the author string, +-- so that the source of the package, and its version number, are visible in compiled copies. +-- +-- 20140911.12 Minor lua cleanup. +-- Fixed internal reference to 'JSON.noKeyConversion' to reference 'self' instead of 'JSON'. +-- (Thanks to SmugMug's David Parry for these.) +-- +-- 20140418.11 JSON nulls embedded within an array were being ignored, such that +-- ["1",null,null,null,null,null,"seven"], +-- would return +-- {1,"seven"} +-- It's now fixed to properly return +-- {1, nil, nil, nil, nil, nil, "seven"} +-- Thanks to "haddock" for catching the error. +-- +-- 20140116.10 The user's JSON.assert() wasn't always being used. Thanks to "blue" for the heads up. +-- +-- 20131118.9 Update for Lua 5.3... it seems that tostring(2/1) produces "2.0" instead of "2", +-- and this caused some problems. +-- +-- 20131031.8 Unified the code for encode() and encode_pretty(); they had been stupidly separate, +-- and had of course diverged (encode_pretty didn't get the fixes that encode got, so +-- sometimes produced incorrect results; thanks to Mattie for the heads up). +-- +-- Handle encoding tables with non-positive numeric keys (unlikely, but possible). +-- +-- If a table has both numeric and string keys, or its numeric keys are inappropriate +-- (such as being non-positive or infinite), the numeric keys are turned into +-- string keys appropriate for a JSON object. So, as before, +-- JSON:encode({ "one", "two", "three" }) +-- produces the array +-- ["one","two","three"] +-- but now something with mixed key types like +-- JSON:encode({ "one", "two", "three", SOMESTRING = "some string" })) +-- instead of throwing an error produces an object: +-- {"1":"one","2":"two","3":"three","SOMESTRING":"some string"} +-- +-- To maintain the prior throw-an-error semantics, set +-- JSON.noKeyConversion = true +-- +-- 20131004.7 Release under a Creative Commons CC-BY license, which I should have done from day one, sorry. +-- +-- 20130120.6 Comment update: added a link to the specific page on my blog where this code can +-- be found, so that folks who come across the code outside of my blog can find updates +-- more easily. +-- +-- 20111207.5 Added support for the 'etc' arguments, for better error reporting. +-- +-- 20110731.4 More feedback from David Kolf on how to make the tests for Nan/Infinity system independent. +-- +-- 20110730.3 Incorporated feedback from David Kolf at http://lua-users.org/wiki/JsonModules: +-- +-- * When encoding lua for JSON, Sparse numeric arrays are now handled by +-- spitting out full arrays, such that +-- JSON:encode({"one", "two", [10] = "ten"}) +-- returns +-- ["one","two",null,null,null,null,null,null,null,"ten"] +-- +-- In 20100810.2 and earlier, only up to the first non-null value would have been retained. +-- +-- * When encoding lua for JSON, numeric value NaN gets spit out as null, and infinity as "1+e9999". +-- Version 20100810.2 and earlier created invalid JSON in both cases. +-- +-- * Unicode surrogate pairs are now detected when decoding JSON. +-- +-- 20100810.2 added some checking to ensure that an invalid Unicode character couldn't leak in to the UTF-8 encoding +-- +-- 20100731.1 initial public release +-- diff --git a/engine/lua/lust.lua b/engine/lua/lust.lua new file mode 100644 index 000000000..9083ec68f --- /dev/null +++ b/engine/lua/lust.lua @@ -0,0 +1,269 @@ +-- lust v0.2.0 - Lua test framework +-- https://github.com/bjornbytes/lust +-- MIT LICENSE + +--[[ +Copyright (c) 2016 Bjorn Swenson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +--]] + +--[[ +Source: https://github.com/bjornbytes/lust/commit/7b4f12844e0e00dfc501b67f5a644232c3b275fe +Modifications: +- Embed LICENSE into this file. +- Add this comment. +--]] + +local lust = {} +lust.level = 0 +lust.passes = 0 +lust.errors = 0 +lust.befores = {} +lust.afters = {} + +local red = string.char(27) .. '[31m' +local green = string.char(27) .. '[32m' +local normal = string.char(27) .. '[0m' +local function indent(level) return string.rep('\t', level or lust.level) end + +function lust.nocolor() + red, green, normal = '', '', '' + return lust +end + +function lust.describe(name, fn) + print(indent() .. name) + lust.level = lust.level + 1 + fn() + lust.befores[lust.level] = {} + lust.afters[lust.level] = {} + lust.level = lust.level - 1 +end + +function lust.it(name, fn) + for level = 1, lust.level do + if lust.befores[level] then + for i = 1, #lust.befores[level] do + lust.befores[level][i](name) + end + end + end + + local success, err = pcall(fn) + if success then lust.passes = lust.passes + 1 + else lust.errors = lust.errors + 1 end + local color = success and green or red + local label = success and 'PASS' or 'FAIL' + print(indent() .. color .. label .. normal .. ' ' .. name) + if err then + print(indent(lust.level + 1) .. red .. tostring(err) .. normal) + end + + for level = 1, lust.level do + if lust.afters[level] then + for i = 1, #lust.afters[level] do + lust.afters[level][i](name) + end + end + end +end + +function lust.before(fn) + lust.befores[lust.level] = lust.befores[lust.level] or {} + table.insert(lust.befores[lust.level], fn) +end + +function lust.after(fn) + lust.afters[lust.level] = lust.afters[lust.level] or {} + table.insert(lust.afters[lust.level], fn) +end + +-- Assertions +local function isa(v, x) + if type(x) == 'string' then + return type(v) == x, + 'expected ' .. tostring(v) .. ' to be a ' .. x, + 'expected ' .. tostring(v) .. ' to not be a ' .. x + elseif type(x) == 'table' then + if type(v) ~= 'table' then + return false, + 'expected ' .. tostring(v) .. ' to be a ' .. tostring(x), + 'expected ' .. tostring(v) .. ' to not be a ' .. tostring(x) + end + + local seen = {} + local meta = v + while meta and not seen[meta] do + if meta == x then return true end + seen[meta] = true + meta = getmetatable(meta) and getmetatable(meta).__index + end + + return false, + 'expected ' .. tostring(v) .. ' to be a ' .. tostring(x), + 'expected ' .. tostring(v) .. ' to not be a ' .. tostring(x) + end + + error('invalid type ' .. tostring(x)) +end + +local function has(t, x) + for k, v in pairs(t) do + if v == x then return true end + end + return false +end + +local function strict_eq(t1, t2) + if type(t1) ~= type(t2) then return false end + if type(t1) ~= 'table' then return t1 == t2 end + for k, _ in pairs(t1) do + if not strict_eq(t1[k], t2[k]) then return false end + end + for k, _ in pairs(t2) do + if not strict_eq(t2[k], t1[k]) then return false end + end + return true +end + +local paths = { + [''] = { 'to', 'to_not' }, + to = { 'have', 'equal', 'be', 'exist', 'fail', 'match' }, + to_not = { 'have', 'equal', 'be', 'exist', 'fail', 'match', chain = function(a) a.negate = not a.negate end }, + a = { test = isa }, + an = { test = isa }, + be = { 'a', 'an', 'truthy', + test = function(v, x) + return v == x, + 'expected ' .. tostring(v) .. ' and ' .. tostring(x) .. ' to be equal', + 'expected ' .. tostring(v) .. ' and ' .. tostring(x) .. ' to not be equal' + end + }, + exist = { + test = function(v) + return v ~= nil, + 'expected ' .. tostring(v) .. ' to exist', + 'expected ' .. tostring(v) .. ' to not exist' + end + }, + truthy = { + test = function(v) + return v, + 'expected ' .. tostring(v) .. ' to be truthy', + 'expected ' .. tostring(v) .. ' to not be truthy' + end + }, + equal = { + test = function(v, x) + return strict_eq(v, x), + 'expected ' .. tostring(v) .. ' and ' .. tostring(x) .. ' to be exactly equal', + 'expected ' .. tostring(v) .. ' and ' .. tostring(x) .. ' to not be exactly equal' + end + }, + have = { + test = function(v, x) + if type(v) ~= 'table' then + error('expected ' .. tostring(v) .. ' to be a table') + end + + return has(v, x), + 'expected ' .. tostring(v) .. ' to contain ' .. tostring(x), + 'expected ' .. tostring(v) .. ' to not contain ' .. tostring(x) + end + }, + fail = { + test = function(v) + return not pcall(v), + 'expected ' .. tostring(v) .. ' to fail', + 'expected ' .. tostring(v) .. ' to not fail' + end + }, + match = { + test = function(v, p) + if type(v) ~= 'string' then v = tostring(v) end + local result = string.find(v, p) + return result ~= nil, + 'expected ' .. v .. ' to match pattern [[' .. p .. ']]', + 'expected ' .. v .. ' to not match pattern [[' .. p .. ']]' + end + }, +} + +function lust.expect(v) + local assertion = {} + assertion.val = v + assertion.action = '' + assertion.negate = false + + setmetatable(assertion, { + __index = function(t, k) + if has(paths[rawget(t, 'action')], k) then + rawset(t, 'action', k) + local chain = paths[rawget(t, 'action')].chain + if chain then chain(t) end + return t + end + return rawget(t, k) + end, + __call = function(t, ...) + if paths[t.action].test then + local res, err, nerr = paths[t.action].test(t.val, ...) + if assertion.negate then + res = not res + err = nerr or err + end + if not res then + error(err or 'unknown failure', 2) + end + end + end + }) + + return assertion +end + +function lust.spy(target, name, run) + local spy = {} + local subject + + local function capture(...) + table.insert(spy, {...}) + return subject(...) + end + + if type(target) == 'table' then + subject = target[name] + target[name] = capture + else + run = name + subject = target or function() end + end + + setmetatable(spy, {__call = function(_, ...) return capture(...) end}) + + if run then run() end + + return spy +end + +lust.test = lust.it +lust.paths = paths + +return lust diff --git a/engine/lua/tableshape.lua b/engine/lua/tableshape.lua new file mode 100644 index 000000000..5fedaee6e --- /dev/null +++ b/engine/lua/tableshape.lua @@ -0,0 +1,2354 @@ +--[[ + Source: https://github.com/leafo/tableshape/blob/v2.6.0/tableshape/init.lua + Modifications: + - Rename tableshape/init.lua -> tableshape.lua. + - Embed MIT license from README.md into this file. + - Add this comment. +]] + +--[[ + Copyright (C) 2022 by Leaf Corcoran + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +]] + +local OptionalType, TaggedType, types, is_type +local BaseType, TransformNode, SequenceNode, FirstOfNode, DescribeNode, NotType, Literal +local FailedTransform = { } +local unpack = unpack or table.unpack +local clone_state +clone_state = function(state_obj) + if type(state_obj) ~= "table" then + return { } + end + local out + do + local _tbl_0 = { } + for k, v in pairs(state_obj) do + _tbl_0[k] = v + end + out = _tbl_0 + end + do + local mt = getmetatable(state_obj) + if mt then + setmetatable(out, mt) + end + end + return out +end +local describe_type +describe_type = function(val) + if type(val) == "string" then + if not val:match('"') then + return "\"" .. tostring(val) .. "\"" + elseif not val:match("'") then + return "'" .. tostring(val) .. "'" + else + return "`" .. tostring(val) .. "`" + end + elseif BaseType:is_base_type(val) then + return val:_describe() + else + return tostring(val) + end +end +local coerce_literal +coerce_literal = function(value) + local _exp_0 = type(value) + if "string" == _exp_0 or "number" == _exp_0 or "boolean" == _exp_0 then + return Literal(value) + elseif "table" == _exp_0 then + if BaseType:is_base_type(value) then + return value + end + end + return nil, "failed to coerce literal into type, use types.literal() to test for literal value" +end +local join_names +join_names = function(items, sep, last_sep) + if sep == nil then + sep = ", " + end + local count = #items + local chunks = { } + for idx, name in ipairs(items) do + if idx > 1 then + local current_sep + if idx == count then + current_sep = last_sep or sep + else + current_sep = sep + end + table.insert(chunks, current_sep) + end + table.insert(chunks, name) + end + return table.concat(chunks) +end +do + local _class_0 + local _base_0 = { + __div = function(self, fn) + return TransformNode(self, fn) + end, + __mod = function(self, fn) + do + local _with_0 = TransformNode(self, fn) + _with_0.with_state = true + return _with_0 + end + end, + __mul = function(_left, _right) + local left, err = coerce_literal(_left) + if not (left) then + error("left hand side of multiplication: " .. tostring(_left) .. ": " .. tostring(err)) + end + local right + right, err = coerce_literal(_right) + if not (right) then + error("right hand side of multiplication: " .. tostring(_right) .. ": " .. tostring(err)) + end + return SequenceNode(left, right) + end, + __add = function(_left, _right) + local left, err = coerce_literal(_left) + if not (left) then + error("left hand side of addition: " .. tostring(_left) .. ": " .. tostring(err)) + end + local right + right, err = coerce_literal(_right) + if not (right) then + error("right hand side of addition: " .. tostring(_right) .. ": " .. tostring(err)) + end + if left.__class == FirstOfNode then + local options = { + unpack(left.options) + } + table.insert(options, right) + return FirstOfNode(unpack(options)) + elseif right.__class == FirstOfNode then + return FirstOfNode(left, unpack(right.options)) + else + return FirstOfNode(left, right) + end + end, + __unm = function(self, right) + return NotType(right) + end, + __tostring = function(self) + return self:_describe() + end, + _describe = function(self) + return error("Node missing _describe: " .. tostring(self.__class.__name)) + end, + check_value = function(self, ...) + local value, state_or_err = self:_transform(...) + if value == FailedTransform then + return nil, state_or_err + end + if type(state_or_err) == "table" then + return state_or_err + else + return true + end + end, + transform = function(self, ...) + local value, state_or_err = self:_transform(...) + if value == FailedTransform then + return nil, state_or_err + end + if type(state_or_err) == "table" then + return value, state_or_err + else + return value + end + end, + repair = function(self, ...) + return self:transform(...) + end, + on_repair = function(self, fn) + return (self + types.any / fn * self):describe(function() + return self:_describe() + end) + end, + is_optional = function(self) + return OptionalType(self) + end, + describe = function(self, ...) + return DescribeNode(self, ...) + end, + tag = function(self, name) + return TaggedType(self, { + tag = name + }) + end, + clone_opts = function(self) + return error("clone_opts is not longer supported") + end, + __call = function(self, ...) + return self:check_value(...) + end + } + _base_0.__index = _base_0 + _class_0 = setmetatable({ + __init = function(self, opts) end, + __base = _base_0, + __name = "BaseType" + }, { + __index = _base_0, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + self.is_base_type = function(self, val) + do + local mt = type(val) == "table" and getmetatable(val) + if mt then + if mt.__class then + return mt.__class.is_base_type == BaseType.is_base_type + end + end + end + return false + end + self.__inherited = function(self, cls) + cls.__base.__call = cls.__call + cls.__base.__div = self.__div + cls.__base.__mod = self.__mod + cls.__base.__mul = self.__mul + cls.__base.__add = self.__add + cls.__base.__unm = self.__unm + cls.__base.__tostring = self.__tostring + end + BaseType = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return self.node:_describe() + end, + _transform = function(self, value, state) + local state_or_err + value, state_or_err = self.node:_transform(value, state) + if value == FailedTransform then + return FailedTransform, state_or_err + else + local out + local _exp_0 = type(self.t_fn) + if "function" == _exp_0 then + if self.with_state then + out = self.t_fn(value, state_or_err) + else + out = self.t_fn(value) + end + else + out = self.t_fn + end + return out, state_or_err + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, node, t_fn) + self.node, self.t_fn = node, t_fn + return assert(self.node, "missing node for transform") + end, + __base = _base_0, + __name = "TransformNode", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + TransformNode = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + local item_names + do + local _accum_0 = { } + local _len_0 = 1 + local _list_0 = self.sequence + for _index_0 = 1, #_list_0 do + local i = _list_0[_index_0] + _accum_0[_len_0] = describe_type(i) + _len_0 = _len_0 + 1 + end + item_names = _accum_0 + end + return join_names(item_names, " then ") + end, + _transform = function(self, value, state) + local _list_0 = self.sequence + for _index_0 = 1, #_list_0 do + local node = _list_0[_index_0] + value, state = node:_transform(value, state) + if value == FailedTransform then + break + end + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + self.sequence = { + ... + } + end, + __base = _base_0, + __name = "SequenceNode", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + SequenceNode = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + local item_names + do + local _accum_0 = { } + local _len_0 = 1 + local _list_0 = self.options + for _index_0 = 1, #_list_0 do + local i = _list_0[_index_0] + _accum_0[_len_0] = describe_type(i) + _len_0 = _len_0 + 1 + end + item_names = _accum_0 + end + return join_names(item_names, ", ", ", or ") + end, + _transform = function(self, value, state) + if not (self.options[1]) then + return FailedTransform, "no options for node" + end + local _list_0 = self.options + for _index_0 = 1, #_list_0 do + local node = _list_0[_index_0] + local new_val, new_state = node:_transform(value, state) + if not (new_val == FailedTransform) then + return new_val, new_state + end + end + return FailedTransform, "expected " .. tostring(self:_describe()) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + self.options = { + ... + } + end, + __base = _base_0, + __name = "FirstOfNode", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + FirstOfNode = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, input, ...) + local value, state = self.node:_transform(input, ...) + if value == FailedTransform then + local err + if self.err_handler then + err = self.err_handler(input, state) + else + err = "expected " .. tostring(self:_describe()) + end + return FailedTransform, err + end + return value, state + end, + describe = function(self, ...) + return DescribeNode(self.node, ...) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, node, describe) + self.node = node + local err_message + if type(describe) == "table" then + describe, err_message = describe.type, describe.error + end + if type(describe) == "string" then + self._describe = function() + return describe + end + else + self._describe = describe + end + if err_message then + if type(err_message) == "string" then + self.err_handler = function() + return err_message + end + else + self.err_handler = err_message + end + end + end, + __base = _base_0, + __name = "DescribeNode", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + DescribeNode = _class_0 +end +local AnnotateNode +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + format_error = function(self, value, err) + return tostring(tostring(value)) .. ": " .. tostring(err) + end, + _transform = function(self, value, state) + local new_value, state_or_err = self.base_type:_transform(value, state) + if new_value == FailedTransform then + return FailedTransform, self:format_error(value, state_or_err) + else + return new_value, state_or_err + end + end, + _describe = function(self) + if self.base_type._describe then + return self.base_type:_describe() + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type, opts) + self.base_type = assert(coerce_literal(base_type)) + if opts then + if opts.format_error then + self.format_error = assert(types.func:transform(opts.format_error)) + end + end + end, + __base = _base_0, + __name = "AnnotateNode", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + AnnotateNode = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + update_state = function(self, state, value, ...) + local out = clone_state(state) + if self.tag_type == "function" then + if select("#", ...) > 0 then + self.tag_name(out, ..., value) + else + self.tag_name(out, value) + end + else + if self.tag_array then + local existing = out[self.tag_name] + if type(existing) == "table" then + local copy + do + local _tbl_0 = { } + for k, v in pairs(existing) do + _tbl_0[k] = v + end + copy = _tbl_0 + end + table.insert(copy, value) + out[self.tag_name] = copy + else + out[self.tag_name] = { + value + } + end + else + out[self.tag_name] = value + end + end + return out + end, + _transform = function(self, value, state) + value, state = self.base_type:_transform(value, state) + if value == FailedTransform then + return FailedTransform, state + end + state = self:update_state(state, value) + return value, state + end, + _describe = function(self) + local base_description = self.base_type:_describe() + return tostring(base_description) .. " tagged " .. tostring(describe_type(self.tag_name)) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type, opts) + if opts == nil then + opts = { } + end + self.base_type = base_type + self.tag_name = assert(opts.tag, "tagged type missing tag") + self.tag_type = type(self.tag_name) + if self.tag_type == "string" then + if self.tag_name:match("%[%]$") then + self.tag_name = self.tag_name:sub(1, -3) + self.tag_array = true + end + end + end, + __base = _base_0, + __name = "TaggedType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + TaggedType = _class_0 +end +local TagScopeType +do + local _class_0 + local _parent_0 = TaggedType + local _base_0 = { + create_scope_state = function(self, state) + return nil + end, + _transform = function(self, value, state) + local scope + value, scope = self.base_type:_transform(value, self:create_scope_state(state)) + if value == FailedTransform then + return FailedTransform, scope + end + if self.tag_name then + state = self:update_state(state, scope, value) + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type, opts) + if opts then + return _class_0.__parent.__init(self, base_type, opts) + else + self.base_type = base_type + end + end, + __base = _base_0, + __name = "TagScopeType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + TagScopeType = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, value, state) + if value == nil then + return value, state + end + return self.base_type:_transform(value, state) + end, + is_optional = function(self) + return self + end, + _describe = function(self) + if self.base_type._describe then + local base_description = self.base_type:_describe() + return "optional " .. tostring(base_description) + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type) + self.base_type = base_type + return assert(BaseType:is_base_type(self.base_type), "expected a type checker") + end, + __base = _base_0, + __name = "OptionalType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + OptionalType = _class_0 +end +local AnyType +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, v, state) + return v, state + end, + _describe = function(self) + return "anything" + end, + is_optional = function(self) + return self + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + return _class_0.__parent.__init(self, ...) + end, + __base = _base_0, + __name = "AnyType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + AnyType = _class_0 +end +local Type +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, value, state) + local got = type(value) + if self.t ~= got then + return FailedTransform, "expected type " .. tostring(describe_type(self.t)) .. ", got " .. tostring(describe_type(got)) + end + if self.length_type then + local len = #value + local res + res, state = self.length_type:_transform(len, state) + if res == FailedTransform then + return FailedTransform, tostring(self.t) .. " length " .. tostring(state) .. ", got " .. tostring(len) + end + end + return value, state + end, + length = function(self, left, right) + local l + if BaseType:is_base_type(left) then + l = left + else + l = types.range(left, right) + end + return Type(self.t, { + length = l + }) + end, + _describe = function(self) + local t = "type " .. tostring(describe_type(self.t)) + if self.length_type then + t = t .. " length_type " .. tostring(self.length_type:_describe()) + end + return t + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, t, opts) + self.t = t + if opts then + if opts.length then + self.length_type = assert(coerce_literal(opts.length)) + end + end + end, + __base = _base_0, + __name = "Type", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Type = _class_0 +end +local ArrayType +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "an array" + end, + _transform = function(self, value, state) + if not (type(value) == "table") then + return FailedTransform, "expecting table" + end + local k = 1 + for i, v in pairs(value) do + if not (type(i) == "number") then + return FailedTransform, "non number field: " .. tostring(i) + end + if not (i == k) then + return FailedTransform, "non array index, got " .. tostring(describe_type(i)) .. " but expected " .. tostring(describe_type(k)) + end + k = k + 1 + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + return _class_0.__parent.__init(self, ...) + end, + __base = _base_0, + __name = "ArrayType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + ArrayType = _class_0 +end +local OneOf +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + local item_names + do + local _accum_0 = { } + local _len_0 = 1 + local _list_0 = self.options + for _index_0 = 1, #_list_0 do + local i = _list_0[_index_0] + if type(i) == "table" and i._describe then + _accum_0[_len_0] = i:_describe() + else + _accum_0[_len_0] = describe_type(i) + end + _len_0 = _len_0 + 1 + end + item_names = _accum_0 + end + return tostring(join_names(item_names, ", ", ", or ")) + end, + _transform = function(self, value, state) + if self.options_hash then + if self.options_hash[value] then + return value, state + end + else + local _list_0 = self.options + for _index_0 = 1, #_list_0 do + local _continue_0 = false + repeat + local item = _list_0[_index_0] + if item == value then + return value, state + end + if BaseType:is_base_type(item) then + local new_value, new_state = item:_transform(value, state) + if new_value == FailedTransform then + _continue_0 = true + break + end + return new_value, new_state + end + _continue_0 = true + until true + if not _continue_0 then + break + end + end + end + return FailedTransform, "expected " .. tostring(self:_describe()) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, options) + self.options = options + assert(type(self.options) == "table", "expected table for options in one_of") + local fast_opts = types.array_of(types.number + types.string) + if fast_opts(self.options) then + do + local _tbl_0 = { } + local _list_0 = self.options + for _index_0 = 1, #_list_0 do + local v = _list_0[_index_0] + _tbl_0[v] = true + end + self.options_hash = _tbl_0 + end + end + end, + __base = _base_0, + __name = "OneOf", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + OneOf = _class_0 +end +local AllOf +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + local item_names + do + local _accum_0 = { } + local _len_0 = 1 + local _list_0 = self.types + for _index_0 = 1, #_list_0 do + local i = _list_0[_index_0] + _accum_0[_len_0] = describe_type(i) + _len_0 = _len_0 + 1 + end + item_names = _accum_0 + end + return join_names(item_names, " and ") + end, + _transform = function(self, value, state) + local _list_0 = self.types + for _index_0 = 1, #_list_0 do + local t = _list_0[_index_0] + value, state = t:_transform(value, state) + if value == FailedTransform then + return FailedTransform, state + end + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, types) + self.types = types + assert(type(self.types) == "table", "expected table for first argument") + local _list_0 = self.types + for _index_0 = 1, #_list_0 do + local checker = _list_0[_index_0] + assert(BaseType:is_base_type(checker), "all_of expects all type checkers") + end + end, + __base = _base_0, + __name = "AllOf", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + AllOf = _class_0 +end +local ArrayOf +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "array of " .. tostring(describe_type(self.expected)) + end, + _transform = function(self, value, state) + local pass, err = types.table(value) + if not (pass) then + return FailedTransform, err + end + if self.length_type then + local len = #value + local res + res, state = self.length_type:_transform(len, state) + if res == FailedTransform then + return FailedTransform, "array length " .. tostring(state) .. ", got " .. tostring(len) + end + end + local is_literal = not BaseType:is_base_type(self.expected) + local copy, k + for idx, item in ipairs(value) do + local skip_item = false + local transformed_item + if is_literal then + if self.expected ~= item then + return FailedTransform, "array item " .. tostring(idx) .. ": expected " .. tostring(describe_type(self.expected)) + else + transformed_item = item + end + else + local item_val + item_val, state = self.expected:_transform(item, state) + if item_val == FailedTransform then + return FailedTransform, "array item " .. tostring(idx) .. ": " .. tostring(state) + end + if item_val == nil and not self.keep_nils then + skip_item = true + else + transformed_item = item_val + end + end + if transformed_item ~= item or skip_item then + if not (copy) then + do + local _accum_0 = { } + local _len_0 = 1 + local _max_0 = idx - 1 + for _index_0 = 1, _max_0 < 0 and #value + _max_0 or _max_0 do + local i = value[_index_0] + _accum_0[_len_0] = i + _len_0 = _len_0 + 1 + end + copy = _accum_0 + end + k = idx + end + end + if copy and not skip_item then + copy[k] = transformed_item + k = k + 1 + end + end + return copy or value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, expected, opts) + self.expected = expected + if opts then + self.keep_nils = opts.keep_nils and true + if opts.length then + self.length_type = assert(coerce_literal(opts.length)) + end + end + end, + __base = _base_0, + __name = "ArrayOf", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + self.type_err_message = "expecting table" + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + ArrayOf = _class_0 +end +local ArrayContains +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + short_circuit = true, + keep_nils = false, + _describe = function(self) + return "array containing " .. tostring(describe_type(self.contains)) + end, + _transform = function(self, value, state) + local pass, err = types.table(value) + if not (pass) then + return FailedTransform, err + end + local is_literal = not BaseType:is_base_type(self.contains) + local contains = false + local copy, k + for idx, item in ipairs(value) do + local skip_item = false + local transformed_item + if is_literal then + if self.contains == item then + contains = true + end + transformed_item = item + else + local item_val, new_state = self.contains:_transform(item, state) + if item_val == FailedTransform then + transformed_item = item + else + state = new_state + contains = true + if item_val == nil and not self.keep_nils then + skip_item = true + else + transformed_item = item_val + end + end + end + if transformed_item ~= item or skip_item then + if not (copy) then + do + local _accum_0 = { } + local _len_0 = 1 + local _max_0 = idx - 1 + for _index_0 = 1, _max_0 < 0 and #value + _max_0 or _max_0 do + local i = value[_index_0] + _accum_0[_len_0] = i + _len_0 = _len_0 + 1 + end + copy = _accum_0 + end + k = idx + end + end + if copy and not skip_item then + copy[k] = transformed_item + k = k + 1 + end + if contains and self.short_circuit then + if copy then + for kdx = idx + 1, #value do + copy[k] = value[kdx] + k = k + 1 + end + end + break + end + end + if not (contains) then + return FailedTransform, "expected " .. tostring(self:_describe()) + end + return copy or value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, contains, opts) + self.contains = contains + assert(self.contains, "missing contains") + if opts then + self.short_circuit = opts.short_circuit and true + self.keep_nils = opts.keep_nils and true + end + end, + __base = _base_0, + __name = "ArrayContains", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + self.type_err_message = "expecting table" + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + ArrayContains = _class_0 +end +local MapOf +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "map of " .. tostring(self.expected_key:_describe()) .. " -> " .. tostring(self.expected_value:_describe()) + end, + _transform = function(self, value, state) + local pass, err = types.table(value) + if not (pass) then + return FailedTransform, err + end + local key_literal = not BaseType:is_base_type(self.expected_key) + local value_literal = not BaseType:is_base_type(self.expected_value) + local transformed = false + local out = { } + for k, v in pairs(value) do + local _continue_0 = false + repeat + local new_k = k + local new_v = v + if key_literal then + if k ~= self.expected_key then + return FailedTransform, "map key expected " .. tostring(describe_type(self.expected_key)) + end + else + new_k, state = self.expected_key:_transform(k, state) + if new_k == FailedTransform then + return FailedTransform, "map key " .. tostring(state) + end + end + if value_literal then + if v ~= self.expected_value then + return FailedTransform, "map value expected " .. tostring(describe_type(self.expected_value)) + end + else + new_v, state = self.expected_value:_transform(v, state) + if new_v == FailedTransform then + return FailedTransform, "map value " .. tostring(state) + end + end + if new_k ~= k or new_v ~= v then + transformed = true + end + if new_k == nil then + _continue_0 = true + break + end + out[new_k] = new_v + _continue_0 = true + until true + if not _continue_0 then + break + end + end + return transformed and out or value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, expected_key, expected_value) + self.expected_key = coerce_literal(expected_key) + self.expected_value = coerce_literal(expected_value) + end, + __base = _base_0, + __name = "MapOf", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + MapOf = _class_0 +end +local Shape +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + open = false, + check_all = false, + is_open = function(self) + return Shape(self.shape, { + open = true, + check_all = self.check_all or nil + }) + end, + _describe = function(self) + local parts + do + local _accum_0 = { } + local _len_0 = 1 + for k, v in pairs(self.shape) do + _accum_0[_len_0] = tostring(describe_type(k)) .. " = " .. tostring(describe_type(v)) + _len_0 = _len_0 + 1 + end + parts = _accum_0 + end + return "{ " .. tostring(table.concat(parts, ", ")) .. " }" + end, + _transform = function(self, value, state) + local pass, err = types.table(value) + if not (pass) then + return FailedTransform, err + end + local check_all = self.check_all + local remaining_keys + do + local _tbl_0 = { } + for key in pairs(value) do + _tbl_0[key] = true + end + remaining_keys = _tbl_0 + end + local errors + local dirty = false + local out = { } + for shape_key, shape_val in pairs(self.shape) do + local item_value = value[shape_key] + if remaining_keys then + remaining_keys[shape_key] = nil + end + local new_val + if BaseType:is_base_type(shape_val) then + new_val, state = shape_val:_transform(item_value, state) + else + if shape_val == item_value then + new_val, state = item_value, state + else + new_val, state = FailedTransform, "expected " .. tostring(describe_type(shape_val)) + end + end + if new_val == FailedTransform then + err = "field " .. tostring(describe_type(shape_key)) .. ": " .. tostring(state) + if check_all then + if errors then + table.insert(errors, err) + else + errors = { + err + } + end + else + return FailedTransform, err + end + else + if new_val ~= item_value then + dirty = true + end + out[shape_key] = new_val + end + end + if remaining_keys and next(remaining_keys) then + if self.open then + for k in pairs(remaining_keys) do + out[k] = value[k] + end + elseif self.extra_fields_type then + for k in pairs(remaining_keys) do + local item_value = value[k] + local tuple + tuple, state = self.extra_fields_type:_transform({ + [k] = item_value + }, state) + if tuple == FailedTransform then + err = "field " .. tostring(describe_type(k)) .. ": " .. tostring(state) + if check_all then + if errors then + table.insert(errors, err) + else + errors = { + err + } + end + else + return FailedTransform, err + end + else + do + local nk = tuple and next(tuple) + if nk then + if nk ~= k then + dirty = true + elseif tuple[nk] ~= item_value then + dirty = true + end + out[nk] = tuple[nk] + else + dirty = true + end + end + end + end + else + local names + do + local _accum_0 = { } + local _len_0 = 1 + for key in pairs(remaining_keys) do + _accum_0[_len_0] = describe_type(key) + _len_0 = _len_0 + 1 + end + names = _accum_0 + end + err = "extra fields: " .. tostring(table.concat(names, ", ")) + if check_all then + if errors then + table.insert(errors, err) + else + errors = { + err + } + end + else + return FailedTransform, err + end + end + end + if errors and next(errors) then + return FailedTransform, table.concat(errors, "; ") + end + return dirty and out or value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, shape, opts) + self.shape = shape + assert(type(self.shape) == "table", "expected table for shape") + if opts then + if opts.extra_fields then + assert(BaseType:is_base_type(opts.extra_fields), "extra_fields_type must be type checker") + self.extra_fields_type = opts.extra_fields + end + self.open = opts.open and true + self.check_all = opts.check_all and true + if self.open then + assert(not self.extra_fields_type, "open can not be combined with extra_fields") + end + if self.extra_fields_type then + return assert(not self.open, "extra_fields can not be combined with open") + end + end + end, + __base = _base_0, + __name = "Shape", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + self.type_err_message = "expecting table" + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Shape = _class_0 +end +local Partial +do + local _class_0 + local _parent_0 = Shape + local _base_0 = { + open = true, + is_open = function(self) + return error("is_open has no effect on Partial") + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + return _class_0.__parent.__init(self, ...) + end, + __base = _base_0, + __name = "Partial", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Partial = _class_0 +end +local Pattern +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "pattern " .. tostring(describe_type(self.pattern)) + end, + _transform = function(self, value, state) + local test_value + if self.coerce then + if BaseType:is_base_type(self.coerce) then + local c_res, err = self.coerce:_transform(value) + if c_res == FailedTransform then + return FailedTransform, err + end + test_value = c_res + else + test_value = tostring(value) + end + else + test_value = value + end + local t_res, err = types.string(test_value) + if not (t_res) then + return FailedTransform, err + end + if test_value:match(self.pattern) then + return value, state + else + return FailedTransform, "doesn't match " .. tostring(self:_describe()) + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, pattern, opts) + self.pattern = pattern + assert(type(self.pattern) == "string", "Pattern must be a string") + if opts then + self.coerce = opts.coerce + return assert(opts.initial_type == nil, "initial_type has been removed from types.pattern (got: " .. tostring(opts.initial_type) .. ")") + end + end, + __base = _base_0, + __name = "Pattern", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Pattern = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return describe_type(self.value) + end, + _transform = function(self, value, state) + if self.value ~= value then + return FailedTransform, "expected " .. tostring(self:_describe()) + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, value) + self.value = value + end, + __base = _base_0, + __name = "Literal", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Literal = _class_0 +end +local Custom +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "custom checker " .. tostring(self.fn) + end, + _transform = function(self, value, state) + local pass, err = self.fn(value, state) + if not (pass) then + return FailedTransform, err or "failed custom check" + end + return value, state + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, fn) + self.fn = fn + return assert(type(self.fn) == "function", "custom checker must be a function") + end, + __base = _base_0, + __name = "Custom", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Custom = _class_0 +end +local Equivalent +do + local _class_0 + local values_equivalent + local _parent_0 = BaseType + local _base_0 = { + _describe = function(self) + return "equivalent to " .. tostring(describe_type(self.val)) + end, + _transform = function(self, value, state) + if values_equivalent(self.val, value) then + return value, state + else + return FailedTransform, "not equivalent to " .. tostring(self.val) + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, val) + self.val = val + end, + __base = _base_0, + __name = "Equivalent", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + local self = _class_0 + values_equivalent = function(a, b) + if a == b then + return true + end + if type(a) == "table" and type(b) == "table" then + local seen_keys = { } + for k, v in pairs(a) do + seen_keys[k] = true + if not (values_equivalent(v, b[k])) then + return false + end + end + for k, v in pairs(b) do + local _continue_0 = false + repeat + if seen_keys[k] then + _continue_0 = true + break + end + if not (values_equivalent(v, a[k])) then + return false + end + _continue_0 = true + until true + if not _continue_0 then + break + end + end + return true + else + return false + end + end + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Equivalent = _class_0 +end +local Range +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, value, state) + local res + res, state = self.value_type:_transform(value, state) + if res == FailedTransform then + return FailedTransform, "range " .. tostring(state) + end + if value < self.left then + return FailedTransform, "not in " .. tostring(self:_describe()) + end + if value > self.right then + return FailedTransform, "not in " .. tostring(self:_describe()) + end + return value, state + end, + _describe = function(self) + return "range from " .. tostring(self.left) .. " to " .. tostring(self.right) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, left, right) + self.left, self.right = left, right + assert(self.left <= self.right, "left range value should be less than right range value") + self.value_type = assert(types[type(self.left)], "couldn't figure out type of range boundary") + end, + __base = _base_0, + __name = "Range", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Range = _class_0 +end +local Proxy +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, ...) + return assert(self.fn(), "proxy missing transformer"):_transform(...) + end, + _describe = function(self, ...) + return assert(self.fn(), "proxy missing transformer"):_describe(...) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, fn) + self.fn = fn + end, + __base = _base_0, + __name = "Proxy", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + Proxy = _class_0 +end +local AssertType +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + assert = assert, + _transform = function(self, value, state) + local state_or_err + value, state_or_err = self.base_type:_transform(value, state) + self.assert(value ~= FailedTransform, state_or_err) + return value, state_or_err + end, + _describe = function(self) + if self.base_type._describe then + local base_description = self.base_type:_describe() + return "assert " .. tostring(base_description) + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type) + self.base_type = base_type + return assert(BaseType:is_base_type(self.base_type), "expected a type checker") + end, + __base = _base_0, + __name = "AssertType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + AssertType = _class_0 +end +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, value, state) + local out, _ = self.base_type:_transform(value, state) + if out == FailedTransform then + return value, state + else + return FailedTransform, "expected " .. tostring(self:_describe()) + end + end, + _describe = function(self) + if self.base_type._describe then + local base_description = self.base_type:_describe() + return "not " .. tostring(base_description) + end + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, base_type) + self.base_type = base_type + return assert(BaseType:is_base_type(self.base_type), "expected a type checker") + end, + __base = _base_0, + __name = "NotType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + NotType = _class_0 +end +local CloneType +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + _transform = function(self, value, state) + local _exp_0 = type(value) + if "nil" == _exp_0 or "string" == _exp_0 or "number" == _exp_0 or "boolean" == _exp_0 then + return value, state + elseif "table" == _exp_0 then + local clone_value + do + local _tbl_0 = { } + for k, v in pairs(value) do + _tbl_0[k] = v + end + clone_value = _tbl_0 + end + do + local mt = getmetatable(value) + if mt then + setmetatable(clone_value, mt) + end + end + return clone_value, state + else + return FailedTransform, tostring(describe_type(value)) .. " is not cloneable" + end + end, + _describe = function(self) + return "cloneable value" + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, ...) + return _class_0.__parent.__init(self, ...) + end, + __base = _base_0, + __name = "CloneType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + CloneType = _class_0 +end +local MetatableIsType +do + local _class_0 + local _parent_0 = BaseType + local _base_0 = { + allow_metatable_update = false, + _transform = function(self, value, state) + local state_or_err + value, state_or_err = types.table:_transform(value, state) + if value == FailedTransform then + return FailedTransform, state_or_err + end + local mt = getmetatable(value) + local new_mt + new_mt, state_or_err = self.metatable_type:_transform(mt, state_or_err) + if new_mt == FailedTransform then + return FailedTransform, "metatable expected: " .. tostring(state_or_err) + end + if new_mt ~= mt then + if self.allow_metatable_update then + setmetatable(value, new_mt) + else + return FailedTransform, "metatable was modified by a type but { allow_metatable_update = true } is not enabled" + end + end + return value, state_or_err + end, + _describe = function(self) + return "has metatable " .. tostring(describe_type(self.metatable_type)) + end + } + _base_0.__index = _base_0 + setmetatable(_base_0, _parent_0.__base) + _class_0 = setmetatable({ + __init = function(self, metatable_type, opts) + if BaseType:is_base_type(metatable_type) then + self.metatable_type = metatable_type + else + self.metatable_type = Literal(metatable_type) + end + if opts then + self.allow_metatable_update = opts.allow_metatable_update and true + end + end, + __base = _base_0, + __name = "MetatableIsType", + __parent = _parent_0 + }, { + __index = function(cls, name) + local val = rawget(_base_0, name) + if val == nil then + local parent = rawget(cls, "__parent") + if parent then + return parent[name] + end + else + return val + end + end, + __call = function(cls, ...) + local _self_0 = setmetatable({}, _base_0) + cls.__init(_self_0, ...) + return _self_0 + end + }) + _base_0.__class = _class_0 + if _parent_0.__inherited then + _parent_0.__inherited(_parent_0, _class_0) + end + MetatableIsType = _class_0 +end +local type_nil = Type("nil") +local type_function = Type("function") +local type_number = Type("number") +types = setmetatable({ + any = AnyType(), + string = Type("string"), + number = type_number, + ["function"] = type_function, + func = type_function, + boolean = Type("boolean"), + userdata = Type("userdata"), + ["nil"] = type_nil, + null = type_nil, + table = Type("table"), + array = ArrayType(), + clone = CloneType(), + integer = Pattern("^%d+$", { + coerce = type_number / tostring + }), + one_of = OneOf, + all_of = AllOf, + shape = Shape, + partial = Partial, + pattern = Pattern, + array_of = ArrayOf, + array_contains = ArrayContains, + map_of = MapOf, + literal = Literal, + range = Range, + equivalent = Equivalent, + custom = Custom, + scope = TagScopeType, + proxy = Proxy, + assert = AssertType, + annotate = AnnotateNode, + metatable_is = MetatableIsType +}, { + __index = function(self, fn_name) + return error("Type checker does not exist: `" .. tostring(fn_name) .. "`") + end +}) +local check_shape +check_shape = function(value, shape) + assert(shape.check_value, "missing check_value method from shape") + return shape:check_value(value) +end +is_type = function(val) + return BaseType:is_base_type(val) +end +return { + check_shape = check_shape, + types = types, + is_type = is_type, + BaseType = BaseType, + FailedTransform = FailedTransform, + VERSION = "2.6.0" +} diff --git a/engine/lua/typecheck.lua b/engine/lua/typecheck.lua new file mode 100644 index 000000000..88e9b2ded --- /dev/null +++ b/engine/lua/typecheck.lua @@ -0,0 +1,1577 @@ +--[[ + Source: https://github.com/gvvaughan/typecheck/blob/v3.0/lib/typecheck/init.lua + Modifications: + - Rename typecheck/init.lua -> typecheck.lua. + - Embed LICENSE.md into this file. + - Add this comment. +]] + +--[[ + Gradual Function Type Checking for Lua 5.1, 5.2, 5.3 & 5.4 + Copyright (C) 2014-2023 Gary V. Vaughan + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGE- + MENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE + FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +]] + +--[[-- + Gradual type checking for Lua functions. + + The behaviour of the functions in this module are controlled by the value + of the `argcheck` field maintained by the `std._debug` module. Not setting + a value prior to loading this module is equivalent to having `argcheck = true`. + + The first line of Lua code in production quality applications that value + execution speed over rigorous function type checking should be: + + require 'std._debug' (false) + + Alternatively, if your project also depends on other `std._debug` hints + remaining enabled: + + require 'std._debug'.argcheck = false + + This mitigates almost all of the overhead of type checking with the + functions from this module. + + @module typecheck +]] + + + +--[[ ====================== ]]-- +--[[ Load optional modules. ]]-- +--[[ ====================== ]]-- + + +local _debug = (function() + local ok, r = pcall(require, 'std._debug') + if not ok then + r = setmetatable({ + -- If this module was required, but there's no std._debug, safe to + -- assume we do want runtime argchecks! + argcheck = true, + -- Similarly, if std.strict is available, but there's no _std.debug, + -- then apply strict global symbol checks to this module! + strict = true, + }, { + __call = function(self, x) + self.argscheck = (x ~= false) + end, + }) + end + + return r +end)() + + +local strict = (function() + local setfenv = rawget(_G, 'setfenv') or function() end + + -- No strict global symbol checks with no std.strict module, even + -- if we found std._debug and requested that! + local r = function(env, level) + setfenv(1+(level or 1), env) + return env + end + + if _debug.strict then + -- Specify `.init` submodule to make sure we only accept + -- lua-stdlib/strict, and not the old strict module from + -- lua-stdlib/lua-stdlib. + local ok, m = pcall(require, 'std.strict.init') + if ok then + r = m + end + end + return r +end)() + + +local _ENV = strict(_G) + + + +--[[ ================== ]]-- +--[[ Lua normalization. ]]-- +--[[ ================== ]]-- + + +local concat = table.concat +local find = string.find +local floor = math.floor +local format = string.format +local gsub = string.gsub +local insert = table.insert +local io_type = io.type +local match = string.match +local remove = table.remove +local sort = table.sort +local sub = string.sub + + +-- Return callable objects. +-- @function callable +-- @param x an object or primitive +-- @return *x* if *x* can be called, otherwise `nil` +-- @usage +-- (callable(functable) or function()end)(args, ...) +local function callable(x) + -- Careful here! + -- Most versions of Lua don't recurse functables, so make sure you + -- always put a real function in __call metamethods. Consequently, + -- no reason to recurse here. + -- func=function() print 'called' end + -- func() --> 'called' + -- functable=setmetatable({}, {__call=func}) + -- functable() --> 'called' + -- nested=setmetatable({}, {__call=function(self, ...) return functable(...)end}) + -- nested() -> 'called' + -- notnested=setmetatable({}, {__call=functable}) + -- notnested() + -- --> stdin:1: attempt to call global 'nested' (a table value) + -- --> stack traceback: + -- --> stdin:1: in main chunk + -- --> [C]: in ? + if type(x) == 'function' or (getmetatable(x) or {}).__call then + return x + end +end + + +-- Return named metamethod, if callable, otherwise `nil`. +-- @param x item to act on +-- @string n name of metamethod to look up +-- @treturn function|nil metamethod function, if callable, otherwise `nil` +local function getmetamethod(x, n) + return callable((getmetatable(x) or {})[n]) +end + + +-- Length of a string or table object without using any metamethod. +-- @function rawlen +-- @tparam string|table x object to act on +-- @treturn int raw length of *x* +-- @usage +-- --> 0 +-- rawlen(setmetatable({}, {__len=function() return 42})) +local function rawlen(x) + -- Lua 5.1 does not implement rawlen, and while # operator ignores + -- __len metamethod, `nil` in sequence is handled inconsistently. + if type(x) ~= 'table' then + return #x + end + + local n = #x + for i = 1, n do + if x[i] == nil then + return i -1 + end + end + return n +end + + +-- Deterministic, functional version of core Lua `#` operator. +-- +-- Respects `__len` metamethod (like Lua 5.2+). Otherwise, always return +-- one less than the lowest integer index with a `nil` value in *x*, where +-- the `#` operator implementation might return the size of the array part +-- of a table. +-- @function len +-- @param x item to act on +-- @treturn int the length of *x* +-- @usage +-- x = {1, 2, 3, nil, 5} +-- --> 5 3 +-- print(#x, len(x)) +local function len(x) + return (getmetamethod(x, '__len') or rawlen)(x) +end + + +-- Return a list of given arguments, with field `n` set to the length. +-- +-- The returned table also has a `__len` metamethod that returns `n`, so +-- `ipairs` and `unpack` behave sanely when there are `nil` valued elements. +-- @function pack +-- @param ... tuple to act on +-- @treturn table packed list of *...* values, with field `n` set to +-- number of tuple elements (including any explicit `nil` elements) +-- @see unpack +-- @usage +-- --> 5 +-- len(pack(nil, 2, 5, nil, nil)) +local pack = (function(f) + local pack_mt = { + __len = function(self) + return self.n + end, + } + + local pack_fn = f or function(...) + return {n=select('#', ...), ...} + end + + return function(...) + return setmetatable(pack_fn(...), pack_mt) + end +end)(rawget(_G, "pack")) + + +-- Like Lua `pairs` iterator, but respect `__pairs` even in Lua 5.1. +-- @function pairs +-- @tparam table t table to act on +-- @treturn function iterator function +-- @treturn table *t*, the table being iterated over +-- @return the previous iteration key +-- @usage +-- for k, v in pairs {'a', b='c', foo=42} do process(k, v) end +local pairs = (function(f) + if not f(setmetatable({},{__pairs=function() return false end})) then + return f + end + + return function(t) + return(getmetamethod(t, '__pairs') or f)(t) + end +end)(pairs) + + +-- Convert a number to an integer and return if possible, otherwise `nil`. +-- @function math.tointeger +-- @param x object to act on +-- @treturn[1] integer *x* converted to an integer if possible +-- @return[2] `nil` otherwise +local tointeger = (function(f) + if f == nil then + -- No host tointeger implementationm use our own. + return function(x) + if type(x) == 'number' and x - floor(x) == 0.0 then + return x + end + end + + elseif f '1' ~= nil then + -- Don't perform implicit string-to-number conversion! + return function(x) + if type(x) == 'number' then + return f(x) + end + end + end + + -- Host tointeger is good! + return f +end)(math.tointeger) + + +-- Return 'integer', 'float' or `nil` according to argument type. +-- +-- To ensure the same behaviour on all host Lua implementations, +-- this function returns 'float' for integer-equivalent floating +-- values, even on Lua 5.3. +-- @function math.type +-- @param x object to act on +-- @treturn[1] string 'integer', if *x* is a whole number +-- @treturn[2] string 'float', for other numbers +-- @return[3] `nil` otherwise +local math_type = math.type or function(x) + if type(x) == 'number' then + return tointeger(x) and 'integer' or 'float' + end +end + + +-- Get a function or functable environment. +-- +-- This version of getfenv works on all supported Lua versions, and +-- knows how to unwrap functables. +-- @function getfenv +-- @tparam function|int fn stack level, C or Lua function or functable +-- to act on +-- @treturn table the execution environment of *fn* +-- @usage +-- callers_environment = getfenv(1) +local getfenv = (function(f) + local debug_getfenv = debug.getfenv + local debug_getinfo = debug.getinfo + local debug_getupvalue = debug.getupvalue + + if debug_getfenv then + + return function(fn) + local n = tointeger(fn or 1) + if n then + if n > 0 then + -- Adjust for this function's stack frame, if fn is non-zero. + n = n + 1 + end + + -- Return an additional nil result to defeat tail call elimination + -- which would remove a stack frame and break numeric *fn* count. + return f(n), nil + end + + if type(fn) ~= 'function' then + -- Unwrap functables: + -- No need to recurse because Lua doesn't support nested functables. + -- __call can only (sensibly) be a function, so no need to adjust + -- stack frame offset either. + fn =(getmetatable(fn) or {}).__call or fn + end + + -- In Lua 5.1, only debug.getfenv works on C functions; but it + -- does not work on stack counts. + return debug_getfenv(fn) + end + + else + + -- Thanks to http://lua-users.org/lists/lua-l/2010-06/msg00313.html + return function(fn) + if fn == 0 then + return _G + end + local n = tointeger(fn or 1) + if n then + fn = debug_getinfo(n + 1, 'f').func + elseif type(fn) ~= 'function' then + fn = (getmetatable(fn) or {}).__call or fn + end + + local name, env + local up = 0 + repeat + up = up + 1 + name, env = debug_getupvalue(fn, up) + until name == '_ENV' or name == nil + return env + end + + end +end)(rawget(_G, 'getfenv')) + + +-- Set a function or functable environment. +-- +-- This version of setfenv works on all supported Lua versions, and +-- knows how to unwrap functables. +-- @function setfenv +-- @tparam function|int fn stack level, C or Lua function or functable +-- to act on +-- @tparam table env new execution environment for *fn* +-- @treturn function function acted upon +-- @usage +-- function clearenv(fn) return setfenv(fn, {}) end +local setfenv = (function(f) + local debug_getinfo = debug.getinfo + local debug_getupvalue = debug.getupvalue + local debug_setfenv = debug.setfenv + local debug_setupvalue = debug.setupvalue + local debug_upvaluejoin = debug.upvaluejoin + + if debug_setfenv then + + return function(fn, env) + local n = tointeger(fn or 1) + if n then + if n > 0 then + n = n + 1 + end + return f(n, env), nil + end + if type(fn) ~= 'function' then + fn =(getmetatable(fn) or {}).__call or fn + end + return debug_setfenv(fn, env) + end + + else + + -- Thanks to http://lua-users.org/lists/lua-l/2010-06/msg00313.html + return function(fn, env) + local n = tointeger(fn or 1) + if n then + if n > 0 then + n = n + 1 + end + fn = debug_getinfo(n, 'f').func + elseif type(fn) ~= 'function' then + fn =(getmetatable(fn) or {}).__call or fn + end + + local up, name = 0, nil + repeat + up = up + 1 + name = debug_getupvalue(fn, up) + until name == '_ENV' or name == nil + if name then + debug_upvaluejoin(fn, up, function() return name end, 1) + debug_setupvalue(fn, up, env) + end + return n ~= 0 and fn or nil + end + + end +end)(rawget(_G, 'setfenv')) + + +-- Either `table.unpack` in newer-, or `unpack` in older Lua implementations. +-- Always defaulting to full packed table unpacking when no index arguments +-- are passed. +-- @function unpack +-- @tparam table t table to act on +-- @int[opt=1] i first index to unpack +-- @int[opt=len(t)] j last index to unpack +-- @return ... values of numeric indices of *t* +-- @see pack +-- @usage +-- local a, b, c = unpack(pack(nil, 2, nil)) +-- assert(a == nil and b == 2 and c == nil) +local unpack = (function(f) + return function(t, i, j) + return f(t, tointeger(i) or 1, tointeger(j) or len(t)) + end +end)(rawget(_G, "unpack") or table.unpack) + + + +--[[ ================= ]]-- +--[[ Helper Functions. ]]-- +--[[ ================= ]]-- + + +local function copy(dest, src) + if src == nil then + dest, src = {}, dest + end + for k, v in pairs(src) do + dest[k] = v + end + return dest +end + + +local function split(s, sep) + local r, pattern = {}, nil + if sep == '' then + pattern = '(.)' + r[#r + 1] = '' + else + pattern = '(.-)' ..(sep or '%s+') + end + local b, slen = 0, len(s) + while b <= slen do + local _, n, m = find(s, pattern, b + 1) + r[#r + 1] = m or sub(s, b + 1, slen) + b = n or slen + 1 + end + return r +end + + + +--[[ ================== ]]-- +--[[ Argument Checking. ]]-- +--[[ ================== ]]-- + + +-- There's an additional stack frame to count over from inside functions +-- with argchecks enabled. +local ARGCHECK_FRAME = 0 + + +local function argerror(name, i, extramsg, level) + level = tointeger(level) or 1 + local s = format("bad argument #%d to '%s'", tointeger(i), name) + if extramsg ~= nil then + s = s .. ' (' .. extramsg .. ')' + end + error(s, level > 0 and level + 2 + ARGCHECK_FRAME or 0) +end + + +-- A rudimentary argument type validation decorator. +-- +-- Return the checked function directly if `_debug.argcheck` is reset, +-- otherwise use check function arguments using predicate functions in +-- the corresponding position in the decorator call. +-- @function checktypes +-- @string name function name to use in error messages +-- @tparam funct predicate return true if checked function argument is +-- valid, otherwise return nil and an error message suitable for +-- *extramsg* argument of @{argerror} +-- @tparam func ... additional predicates for subsequent checked +-- function arguments +-- @raises argerror when an argument validator returns failure +-- @see argerror +-- @return function +-- @usage +-- local unpack = checktypes('unpack', types.table) .. +-- function(t, i, j) +-- return table.unpack(t, i or 1, j or #t) +-- end +local checktypes = (function() + -- Set checktypes according to whether argcheck was required by _debug. + if _debug.argcheck then + + ARGCHECK_FRAME = 1 + + local function icalls(checks, argu) + return function(state, i) + if i < state.checks.n then + i = i + 1 + local r = pack(state.checks[i](state.argu, i)) + if r.n > 0 then + return i, r[1], r[2] + end + return i + end + end, {argu=argu, checks=checks}, 0 + end + + return function(name, ...) + return setmetatable(pack(...), { + __concat = function(checks, inner) + if not callable(inner) then + error("attempt to annotate non-callable value with 'checktypes'", 2) + end + return function(...) + local argu = pack(...) + for i, expected, got in icalls(checks, argu) do + if got or expected then + local buf, extramsg = {}, nil + if expected then + got = got or ('got ' .. type(argu[i])) + buf[#buf +1] = expected .. ' expected, ' .. got + elseif got then + buf[#buf +1] = got + end + if #buf > 0 then + extramsg = concat(buf) + end + return argerror(name, i, extramsg, 3), nil + end + end + -- Tail call pessimisation: inner might be counting frames, + -- and have several return values that need preserving. + -- Different Lua implementations tail call under differing + -- conditions, so we need this hair to make sure we always + -- get the same number of stack frames interposed. + local results = pack(inner(...)) + return unpack(results, 1, results.n) + end + end, + }) + end + + else + + -- Return `inner` untouched, for no runtime overhead! + return function(...) + return setmetatable({}, { + __concat = function(_, inner) + return inner + end, + }) + end + + end +end)() + + +local function resulterror(name, i, extramsg, level) + level = level or 1 + local s = format("bad result #%d from '%s'", i, name) + if extramsg ~= nil then + s = s .. ' (' .. extramsg .. ')' + end + error(s, level > 0 and level + 1 + ARGCHECK_FRAME or 0) +end + + + +--[[ ================= ]]-- +--[[ Type annotations. ]]-- +--[[ ================= ]]-- + + +local function fail(expected, argu, i, got) + if i > argu.n then + return expected, 'got no value' + elseif got ~= nil then + return expected, 'got ' .. got + end + return expected +end + + +--- Low-level type conformance check helper. +-- +-- Use this, with a simple @{Predicate} function, to write concise argument +-- type check functions. +-- @function check +-- @string expected name of the expected type +-- @tparam table argu a packed table (including `n` field) of all arguments +-- @int i index into *argu* for argument to action +-- @tparam Predicate predicate check whether `argu[i]` matches `expected` +-- @usage +-- function callable(argu, i) +-- return check('string', argu, i, function(x) +-- return type(x) == 'string' +-- end) +-- end +local function check(expected, argu, i, predicate) + local arg = argu[i] + local ok, got = predicate(arg) + if not ok then + return fail(expected, argu, i, got) + end +end + + +local function _type(x) + return (getmetatable(x) or {})._type or io_type(x) or math_type(x) or type(x) +end + + +local types = setmetatable({ + -- Accept argu[i]. + accept = function() end, + + -- Reject missing argument *i*. + arg = function(argu, i) + if i > argu.n then + return 'no value' + end + end, + + -- Accept function valued or `__call` metamethod carrying argu[i]. + callable = function(argu, i) + return check('callable', argu, i, callable) + end, + + -- Accept argu[i] if it is an integer valued number + integer = function(argu, i) + local value = argu[i] + if type(tonumber(value)) ~= 'number' then + return fail('integer', argu, i) + end + if tointeger(value) == nil then + return nil, _type(value) .. ' has no integer representation' + end + end, + + -- Accept missing argument *i* (but not explicit `nil`). + missing = function(argu, i) + if i <= argu.n then + return nil + end + end, + + -- Accept non-nil valued argu[i]. + value = function(argu, i) + if i > argu.n then + return 'value', 'got no value' + elseif argu[i] == nil then + return 'value' + end + end, +}, { + __index = function(_, k) + -- Accept named primitive valued argu[i]. + return function(argu, i) + return check(k, argu, i, function(x) + return type(x) == k + end) + end + end, +}) + + +local function any(...) + local fns = {...} + return function(argu, i) + local buf = {} + local expected, got, r + for _, predicate in ipairs(fns) do + r = pack(predicate(argu, i)) + expected, got = r[1], r[2] + if r.n == 0 then + -- A match! + return + elseif r.n == 2 and expected == nil and #got > 0 then + -- Return non-type based mismatch immediately. + return expected, got + elseif expected ~= 'nil' then + -- Record one of the types we would have matched. + buf[#buf + 1] = expected + end + end + if #buf == 0 then + return got + elseif #buf > 1 then + sort(buf) + buf[#buf -1], buf[#buf] = buf[#buf -1] .. ' or ' .. buf[#buf], nil + end + expected = concat(buf, ', ') + if got ~= nil then + return expected, got + end + return expected + end +end + + +local function opt(...) + return any(types['nil'], ...) +end + + + +--[[ =============================== ]]-- +--[[ Implementation of value checks. ]]-- +--[[ =============================== ]]-- + + +local function xform_gsub(pattern, replace) + return function(s) + return (gsub(s, pattern, replace)) + end +end + + +local ORCONCAT_XFORMS = { + xform_gsub('#table', 'non-empty table'), + xform_gsub('#list', 'non-empty list'), + xform_gsub('functor', 'functable'), + xform_gsub('list of', '\t%0'), -- tab sorts before any other printable + xform_gsub('table of', '\t%0'), +} + + +--- Concatenate a table of strings using ', ' and ' or ' delimiters. +-- @tparam table alternatives a table of strings +-- @treturn string string of elements from alternatives delimited by ', ' +-- and ' or ' +local function orconcat(alternatives) + if len(alternatives) > 1 then + local t = copy(alternatives) + sort(t, function(a, b) + for _, fn in ipairs(ORCONCAT_XFORMS) do + a, b = fn(a), fn(b) + end + return a < b + end) + local top = remove(t) + t[#t] = t[#t] .. ' or ' .. top + alternatives = t + end + return concat(alternatives, ', ') +end + + +local EXTRAMSG_XFORMS = { + xform_gsub('any value or nil', 'argument'), + xform_gsub('#table', 'non-empty table'), + xform_gsub('#list', 'non-empty list'), + xform_gsub('functor', 'functable'), + xform_gsub('(%S+ of) bool([,%s])', '%1 boolean%2'), + xform_gsub('(%S+ of) func([,%s])', '%1 function%2'), + xform_gsub('(%S+ of) int([,%s])', '%1 integer%2'), + xform_gsub('(%S+ of [^,%s]-)s?([,%s])', '%1s%2'), + xform_gsub('(s, [^,%s]-)s?([,%s])', '%1s%2'), + xform_gsub('(of .-)s? or ([^,%s]-)s? ', '%1s or %2s '), +} + + +local function extramsg_mismatch(i, expectedtypes, argu, key) + local actual, actualtype + + if type(i) ~= 'number' then + -- Support the old (expectedtypes, actual, key) calling convention. + expectedtypes, actual, key, argu = i, expectedtypes, argu, nil + actualtype = _type(actual) + else + -- Support the new (i, expectedtypes, argu) convention, which can + -- diagnose missing arguments properly. + actual = argu[i] + if i > argu.n then + actualtype = 'no value' + else + actualtype = _type(actual) or type(actual) + end + end + + -- Tidy up actual type for display. + if actualtype == 'string' and sub(actual, 1, 1) == ':' then + actualtype = actual + elseif type(actual) == 'table' then + if actualtype == 'table' and (getmetatable(actual) or {}).__call ~= nil then + actualtype = 'functable' + elseif next(actual) == nil then + local matchstr = ',' .. concat(expectedtypes, ',') .. ',' + if actualtype == 'table' and matchstr == ',#list,' then + actualtype = 'empty list' + elseif actualtype == 'table' or match(matchstr, ',#') then + actualtype = 'empty ' .. actualtype + end + end + end + + if key then + actualtype = actualtype .. ' at index ' .. tostring(key) + end + + -- Tidy up expected types for display. + local expectedstr = expectedtypes + if type(expectedtypes) == 'table' then + local t = {} + for i, v in ipairs(expectedtypes) do + if v == 'func' then + t[i] = 'function' + elseif v == 'bool' then + t[i] = 'boolean' + elseif v == 'int' then + t[i] = 'integer' + elseif v == 'any' then + t[i] = 'any value' + elseif v == 'file' then + t[i] = 'FILE*' + elseif not key then + t[i] = match(v, '(%S+) of %S+') or v + else + t[i] = v + end + end + expectedstr = orconcat(t) .. ' expected' + for _, fn in ipairs(EXTRAMSG_XFORMS) do + expectedstr = fn(expectedstr) + end + end + + if expectedstr == 'integer expected' and tonumber(actual) then + if tointeger(actual) == nil then + return actualtype .. ' has no integer representation' + end + end + + return expectedstr .. ', got ' .. actualtype +end + + +--- Compare *check* against type of *actual*. *check* must be a single type +-- @string expected extended type name expected +-- @param actual object being typechecked +-- @treturn boolean `true` if *actual* is of type *check*, otherwise +-- `false` +local function checktype(expected, actual) + if expected == 'any' and actual ~= nil then + return true + elseif expected == 'file' and io_type(actual) == 'file' then + return true + elseif expected == 'functable' or expected == 'callable' or expected == 'functor' then + if (getmetatable(actual) or {}).__call ~= nil then + return true + end + end + + local actualtype = type(actual) + if expected == actualtype then + return true + elseif expected == 'bool' and actualtype == 'boolean' then + return true + elseif expected == '#table' then + if actualtype == 'table' and next(actual) then + return true + end + elseif expected == 'func' or expected == 'callable' then + if actualtype == 'function' then + return true + end + elseif expected == 'int' or expected == 'integer' then + if actualtype == 'number' and actual == floor(actual) then + return true + end + elseif type(expected) == 'string' and sub(expected, 1, 1) == ':' then + if expected == actual then + return true + end + end + + actualtype = _type(actual) + if expected == actualtype then + return true + elseif expected == 'list' or expected == '#list' then + if actualtype == 'table' or actualtype == 'List' then + local n, count = len(actual), 0 + local i = next(actual) + repeat + if i ~= nil then + count = count + 1 + end + i = next(actual, i) + until i == nil or count > n + if count == n and (expected == 'list' or count > 0) then + return true + end + end + elseif expected == 'object' then + if actualtype ~= 'table' and type(actual) == 'table' then + return true + end + end + + return false +end + + +local function typesplit(typespec) + if type(typespec) == 'string' then + typespec = split(gsub(typespec, '%s+or%s+', '|'), '%s*|%s*') + end + local r, seen, add_nil = {}, {}, false + for _, v in ipairs(typespec) do + local m = match(v, '^%?(.+)$') + if m then + add_nil, v = true, m + end + if not seen[v] then + r[#r + 1] = v + seen[v] = true + end + end + if add_nil then + r[#r + 1] = 'nil' + end + return r +end + + +local function checktypespec(expected, actual) + expected = typesplit(expected) + + -- Check actual has one of the types from expected + for _, expect in ipairs(expected) do + local container, contents = match(expect, '^(%S+) of (%S-)s?$') + container = container or expect + + -- Does the type of actual check out? + local ok = checktype(container, actual) + + -- For 'table of things', check all elements are a thing too. + if ok and contents and type(actual) == 'table' then + for k, v in pairs(actual) do + if not checktype(contents, v) then + return nil, extramsg_mismatch(expected, v, k) + end + end + end + if ok then + return true + end + end + + return nil, extramsg_mismatch(expected, actual) +end + + + +--[[ ================================== ]]-- +--[[ Implementation of function checks. ]]-- +--[[ ================================== ]]-- + + +local function extramsg_toomany(bad, expected, actual) + local s = 'no more than %d %s%s expected, got %d' + return format(s, expected, bad, expected == 1 and '' or 's', actual) +end + + +--- Strip trailing ellipsis from final argument if any, storing maximum +-- number of values that can be matched directly in `t.maxvalues`. +-- @tparam table t table to act on +-- @string v element added to *t*, to match against ... suffix +-- @treturn table *t* with ellipsis stripped and maxvalues field set +local function markdots(t, v) + return (gsub(v, '%.%.%.$', function() + t.dots = true return '' + end)) +end + + +--- Calculate permutations of type lists with and without [optionals]. +-- @tparam table t a list of expected types by argument position +-- @treturn table set of possible type lists +local function permute(t) + if t[#t] then + t[#t] = gsub(t[#t], '%]%.%.%.$', '...]') + end + + local p = {{}} + for _, v in ipairs(t) do + local optional = match(v, '%[(.+)%]') + + if optional == nil then + -- Append non-optional type-spec to each permutation. + for b = 1, #p do + insert(p[b], markdots(p[b], v)) + end + else + -- Duplicate all existing permutations, and add optional type-spec + -- to the unduplicated permutations. + local o = #p + for b = 1, o do + p[b + o] = copy(p[b]) + insert(p[b], markdots(p[b], optional)) + end + end + end + return p +end + + +local function projectuniq(fkey, tt) + -- project + local t = {} + for _, u in ipairs(tt) do + t[#t + 1] = u[fkey] + end + + -- split and remove duplicates + local r, s = {}, {} + for _, e in ipairs(t) do + for _, v in ipairs(typesplit(e)) do + if s[v] == nil then + r[#r + 1], s[v] = v, true + end + end + end + return r +end + + +local function parsetypes(typespec) + local r, permutations = {}, permute(typespec) + for i = 1, #permutations[1] do + r[i] = projectuniq(i, permutations) + end + r.dots = permutations[1].dots + return r +end + + + +local argcheck = (function() + if _debug.argcheck then + + return function(name, i, expected, actual, level) + level = level or 1 + local _, err = checktypespec(expected, actual) + if err then + argerror(name, i, err, level + 1) + end + end + + else + + return function(...) + return ... + end + + end +end)() + + +local argscheck = (function() + if _debug.argcheck then + + --- Return index of the first mismatch between types and values, or `nil`. + -- @tparam table typelist a list of expected types + -- @tparam table valuelist a table of arguments to compare + -- @treturn int|nil position of first mismatch in *typelist* + local function typematch(typelist, valuelist) + local n = #typelist + for i = 1, n do -- normal parameters + local ok = pcall(argcheck, 'pcall', i, typelist[i], valuelist[i]) + if not ok or i > valuelist.n then + return i + end + end + for i = n + 1, valuelist.n do -- additional values against final type + local ok = pcall(argcheck, 'pcall', i, typelist[n], valuelist[i]) + if not ok then + return i + end + end + end + + + --- Diagnose mismatches between *valuelist* and type *permutations*. + -- @tparam table valuelist list of actual values to be checked + -- @tparam table argt table of precalculated values and handler functiens + local function diagnose(valuelist, argt) + local permutations = argt.permutations + local bestmismatch, t + + bestmismatch = 0 + for i, typelist in ipairs(permutations) do + local mismatch = typematch(typelist, valuelist) + if mismatch == nil then + bestmismatch, t = nil, nil + break -- every *valuelist* matched types from this *typelist* + elseif mismatch > bestmismatch then + bestmismatch, t = mismatch, permutations[i] + end + end + + if bestmismatch ~= nil then + -- Report an error for all possible types at bestmismatch index. + local i, expected = bestmismatch, nil + if t.dots and i > #t then + expected = typesplit(t[#t]) + else + expected = projectuniq(i, permutations) + end + + -- This relies on the `permute()` algorithm leaving the longest + -- possible permutation(with dots if necessary) at permutations[1]. + local typelist = permutations[1] + + -- For 'container of things', check all elements are a thing too. + if typelist[i] then + local contents = match(typelist[i], '^%S+ of (%S-)s?$') + if contents and type(valuelist[i]) == 'table' then + for k, v in pairs(valuelist[i]) do + if not checktype(contents, v) then + argt.badtype(i, extramsg_mismatch(expected, v, k), 3) + end + end + end + end + + -- Otherwise the argument type itself was mismatched. + if t.dots or #t >= valuelist.n then + argt.badtype(i, extramsg_mismatch(i, expected, valuelist), 3) + end + end + + local n = valuelist.n + t = t or permutations[1] + if t and t.dots == nil and n > #t then + argt.badtype(#t + 1, extramsg_toomany(argt.bad, #t, n), 3) + end + end + + + -- Pattern to extract: fname([types]?[, types]*) + local args_pattern = '^%s*([%w_][%.%:%d%w_]*)%s*%(%s*(.*)%s*%)' + + return function(decl, inner) + -- Parse 'fname(argtype, argtype, argtype...)'. + local fname, argtypes = match(decl, args_pattern) + if argtypes == '' then + argtypes = {} + elseif argtypes then + argtypes = split(argtypes, '%s*,%s*') + else + fname = match(decl, '^%s*([%w_][%.%:%d%w_]*)') + end + + -- Precalculate vtables once to make multiple calls faster. + local input = { + bad = 'argument', + badtype = function(i, extramsg, level) + level = level or 1 + argerror(fname, i, extramsg, level + 1) + end, + permutations = permute(argtypes), + } + + -- Parse '... => returntype, returntype, returntype...'. + local output, returntypes = nil, match(decl, '=>%s*(.+)%s*$') + if returntypes then + local i, permutations = 0, {} + for _, group in ipairs(split(returntypes, '%s+or%s+')) do + returntypes = split(group, ',%s*') + for _, t in ipairs(permute(returntypes)) do + i = i + 1 + permutations[i] = t + end + end + + -- Ensure the longest permutation is first in the list. + sort(permutations, function(a, b) + return #a > #b + end) + + output = { + bad = 'result', + badtype = function(i, extramsg, level) + level = level or 1 + resulterror(fname, i, gsub(extramsg, 'argument( expected,)', 'result%1'), level + 1) + end, + permutations = permutations, + } + end + + local wrap_function = function(my_inner) + return function(...) + local argt = pack(...) + + -- Don't check type of self if fname has a ':' in it. + if find(fname, ':') then + remove(argt, 1) + argt.n = argt.n - 1 + end + + -- Diagnose bad inputs. + diagnose(argt, input) + + -- Propagate outer environment to inner function. + if type(my_inner) == 'table' then + setfenv((getmetatable(my_inner) or {}).__call, getfenv(1)) + else + setfenv(my_inner, getfenv(1)) + end + + -- Execute. + local results = pack(my_inner(...)) + + -- Diagnose bad outputs. + if returntypes then + diagnose(results, output) + end + + return unpack(results, 1, results.n) + end + end + + if inner then + return wrap_function(inner) + else + return setmetatable({}, { + __concat = function(_, concat_inner) + return wrap_function(concat_inner) + end + }) + end + end + + else + + -- Turn off argument checking if _debug is false, or a table containing + -- a false valued `argcheck` field. + return function(_, inner) + if inner then + return inner + else + return setmetatable({}, { + __concat = function(_, concat_inner) + return concat_inner + end + }) + end + end + + end +end)() + + +local T = types + +return setmetatable({ + --- Add this to any stack frame offsets when argchecks are in force. + -- @int ARGCHECK_FRAME + ARGCHECK_FRAME = ARGCHECK_FRAME, + + --- Check the type of an argument against expected types. + -- Equivalent to luaL_argcheck in the Lua C API. + -- + -- Call `argerror` if there is a type mismatch. + -- + -- Argument `actual` must match one of the types from in `expected`, each + -- of which can be the name of a primitive Lua type, a stdlib object type, + -- or one of the special options below: + -- + -- #table accept any non-empty table + -- any accept any non-nil argument type + -- callable accept a function or a functable + -- file accept an open file object + -- func accept a function + -- function accept a function + -- functable accept an object with a __call metamethod + -- int accept an integer valued number + -- list accept a table where all keys are a contiguous 1-based integer range + -- #list accept any non-empty list + -- object accept any std.Object derived type + -- :foo accept only the exact string ':foo', works for any :-prefixed string + -- + -- The `:foo` format allows for type-checking of self-documenting + -- boolean-like constant string parameters predicated on `nil` versus + -- `:option` instead of `false` versus `true`. Or you could support + -- both: + -- + -- argcheck('table.copy', 2, 'boolean|:nometa|nil', nometa) + -- + -- A very common pattern is to have a list of possible types including + -- 'nil' when the argument is optional. Rather than writing long-hand + -- as above, prepend a question mark to the list of types and omit the + -- explicit 'nil' entry: + -- + -- argcheck('table.copy', 2, '?boolean|:nometa', predicate) + -- + -- Normally, you should not need to use the `level` parameter, as the + -- default is to blame the caller of the function using `argcheck` in + -- error messages; which is almost certainly what you want. + -- @function argcheck + -- @string name function to blame in error message + -- @int i argument number to blame in error message + -- @string expected specification for acceptable argument types + -- @param actual argument passed + -- @int[opt=2] level call stack level to blame for the error + -- @usage + -- local function case(with, branches) + -- argcheck('std.functional.case', 2, '#table', branches) + -- ... + argcheck = checktypes( + 'argcheck', T.string, T.integer, T.string, T.accept, opt(T.integer) + ) .. argcheck, + + --- Raise a bad argument error. + -- Equivalent to luaL_argerror in the Lua C API. This function does not + -- return. The `level` argument behaves just like the core `error` + -- function. + -- @function argerror + -- @string name function to callout in error message + -- @int i argument number + -- @string[opt] extramsg additional text to append to message inside parentheses + -- @int[opt=1] level call stack level to blame for the error + -- @see resulterror + -- @see extramsg_mismatch + -- @usage + -- local function slurp(file) + -- local h, err = input_handle(file) + -- if h == nil then + -- argerror('std.io.slurp', 1, err, 2) + -- end + -- ... + argerror = checktypes( + 'argerror', T.string, T.integer, T.accept, opt(T.integer) + ) .. argerror, + + --- Wrap a function definition with argument type and arity checking. + -- In addition to checking that each argument type matches the corresponding + -- element in the *types* table with `argcheck`, if the final element of + -- *types* ends with an ellipsis, remaining unchecked arguments are checked + -- against that type: + -- + -- format = argscheck('string.format(string, ?any...)', string.format) + -- + -- A colon in the function name indicates that the argument type list does + -- not have a type for `self`: + -- + -- format = argscheck('string:format(?any...)', string.format) + -- + -- If an argument can be omitted entirely, then put its type specification + -- in square brackets: + -- + -- insert = argscheck('table.insert(table, [int], ?any)', table.insert) + -- + -- Similarly return types can be checked with the same list syntax as + -- arguments: + -- + -- len = argscheck('string.len(string) => int', string.len) + -- + -- Additionally, variant return type lists can be listed like this: + -- + -- open = argscheck('io.open(string, ?string) => file or nil, string', + -- io.open) + -- + -- @function argscheck + -- @string decl function type declaration string + -- @func inner function to wrap with argument checking + -- @return function + -- @usage + -- local case = argscheck('std.functional.case(?any, #table) => [any...]', + -- function(with, branches) + -- ... + -- end) + -- + -- -- Alternatively, as an annotation: + -- local case = argscheck 'std.functional.case(?any, #table) => [any...]' .. + -- function(with, branches) + -- ... + -- end + argscheck = checktypes( + 'argscheck', T.string, opt(T.callable) + ) .. argscheck, + + --- Checks the type of *actual* against the *expected* typespec + -- @function check + -- @tparam string expected expected typespec + -- @param actual object being typechecked + -- @treturn[1] bool `true`, if *actual* matches *expected* + -- @return[2] `nil` + -- @treturn[2] string an @{extramsg_mismatch} format error message, otherwise + -- @usage + -- --> stdin:2: string or number expected, got empty table + -- assert(check('string|number', {})) + check = checktypespec, + + --- Format a type mismatch error. + -- @function extramsg_mismatch + -- @int[opt] i index of *argu* to be matched with + -- @string expected a pipe delimited list of matchable types + -- @tparam table argu packed table of all arguments + -- @param[opt] key erroring container element key + -- @treturn string formatted *extramsg* for this mismatch for @{argerror} + -- @see argerror + -- @see resulterror + -- @usage + -- if fmt ~= nil and type(fmt) ~= 'string' then + -- argerror('format', 1, extramsg_mismatch(1, '?string', argu)) + -- end + extramsg_mismatch = function(i, expected, argu, key) + if tointeger(i) and type(expected) == 'string' then + expected = typesplit(expected) + else + -- support old (expected, actual, key) calling convention + i = typesplit(i) + end + return extramsg_mismatch(i, expected, argu, key) + end, + + --- Format a too many things error. + -- @function extramsg_toomany + -- @string bad the thing there are too many of + -- @int expected maximum number of *bad* things expected + -- @int actual actual number of *bad* things that triggered the error + -- @see argerror + -- @see resulterror + -- @see extramsg_mismatch + -- @usage + -- if select('#', ...) > 7 then + -- argerror('sevenses', 8, extramsg_toomany('argument', 7, select('#', ...))) + -- end + extramsg_toomany = extramsg_toomany, + + --- Create an @{ArgCheck} predicate for an optional argument. + -- + -- This function satisfies the @{ArgCheck} interface in order to be + -- useful as an argument to @{argscheck} when a particular argument + -- is optional. + -- @function opt + -- @tparam ArgCheck ... type predicate callables + -- @treturn ArgCheck a new function that calls all passed + -- predicates, and combines error messages if all fail + -- @usage + -- getfenv = argscheck( + -- 'getfenv', opt(types.integer, types.callable) + -- ) .. getfenv + opt = opt, + + --- Compact permutation list into a list of valid types at each argument. + -- Eliminate bracketed types by combining all valid types at each position + -- for all permutations of *typelist*. + -- @function parsetypes + -- @tparam list types a normalized list of type names + -- @treturn list valid types for each positional parameter + parsetypes = parsetypes, + + --- Raise a bad result error. + -- Like @{argerror} for bad results. This function does not + -- return. The `level` argument behaves just like the core `error` + -- function. + -- @function resulterror + -- @string name function to callout in error message + -- @int i result number + -- @string[opt] extramsg additional text to append to message inside parentheses + -- @int[opt=1] level call stack level to blame for the error + -- @usage + -- local function slurp(file) + -- ... + -- if type(result) ~= 'string' then + -- resulterror('std.io.slurp', 1, err, 2) + -- end + resulterror = checktypes( + 'resulterror', T.string, T.integer, T.accept, opt(T.integer) + ) .. resulterror, + + --- A collection of @{ArgCheck} functions used by `normalize` APIs. + -- @table types + -- @tfield ArgCheck accept always succeeds + -- @tfield ArgCheck callable accept a function or functable + -- @tfield ArgCheck integer accept integer valued number + -- @tfield ArgCheck nil accept only `nil` + -- @tfield ArgCheck table accept any table + -- @tfield ArgCheck value accept any non-`nil` value + types = types, + + --- Split a typespec string into a table of normalized type names. + -- @function typesplit + -- @tparam string|table either `"?bool|:nometa"` or `{"boolean", ":nometa"}` + -- @treturn table a new list with duplicates removed and leading '?'s + -- replaced by a 'nil' element + typesplit = typesplit, + +}, { + + --- Metamethods + -- @section metamethods + + --- Lazy loading of typecheck modules. + -- Don't load everything on initial startup, wait until first attempt + -- to access a submodule, and then load it on demand. + -- @function __index + -- @string name submodule name + -- @treturn table|nil the submodule that was loaded to satisfy the missing + -- `name`, otherwise `nil` if nothing was found + -- @usage + -- local version = require 'typecheck'.version + __index = function(self, name) + local ok, t = pcall(require, 'typecheck.' .. name) + if ok then + rawset(self, name, t) + return t + end + end, +}) + + +--- Types +-- @section types + +--- Signature of an @{argscheck} callable. +-- @function ArgCheck +-- @tparam table argu a packed table (including `n` field) of all arguments +-- @int index into @argu* for argument to action +-- @return[1] nothing, to accept `argu[i]` +-- @treturn[2] string error message, to reject `argu[i]` immediately +-- @treturn[3] string the expected type of `argu[i]` +-- @treturn[3] string a description of rejected `argu[i]` +-- @usage +-- len = argscheck('len', any(types.table, types.string)) .. len + +--- Signature of a @{check} type predicate callable. +-- @function Predicate +-- @param x object to action +-- @treturn boolean `true` if *x* is of the expected type, otherwise `false` +-- @treturn[opt] string description of the actual type for error message diff --git a/engine/src/coordinator.cpp b/engine/src/coordinator.cpp index 10bc74bad..31f91c8a4 100644 --- a/engine/src/coordinator.cpp +++ b/engine/src/coordinator.cpp @@ -29,6 +29,9 @@ #include // for move #include // for vector<> +#include +#include + #include // for Json, Duration, Logger #include // for HandlerType, Request #include // for Registrar @@ -36,6 +39,9 @@ #include // for Trigger using namespace cloe; // NOLINT(build/namespaces) +#include "lua_action.hpp" +#include "lua_api.hpp" + namespace engine { template @@ -50,7 +56,8 @@ void to_json(Json& j, const HistoryTrigger& t) { j["at"] = t.when; } -Coordinator::Coordinator() : executer_registrar_(trigger_registrar(Source::TRIGGER)) {} +Coordinator::Coordinator(sol::state_view lua) + : lua_(lua), executer_registrar_(trigger_registrar(Source::TRIGGER)) {} class TriggerRegistrar : public cloe::TriggerRegistrar { public: @@ -170,6 +177,12 @@ void Coordinator::register_event(const std::string& key, EventFactoryPtr&& ef, std::bind(&Coordinator::execute_trigger, this, std::placeholders::_1, std::placeholders::_2)); } +sol::table Coordinator::register_lua_table(const std::string& field) { + auto tbl = lua_.create_table(); + luat_cloe_engine_plugins(lua_)[field] = tbl; + return tbl; +} + cloe::CallbackResult Coordinator::execute_trigger(TriggerPtr&& t, const Sync& sync) { logger()->debug("Execute trigger {}", inline_json(*t)); auto result = (t->action())(sync, *executer_registrar_); @@ -179,6 +192,28 @@ cloe::CallbackResult Coordinator::execute_trigger(TriggerPtr&& t, const Sync& sy return result; } +void Coordinator::execute_action_from_lua(const Sync& sync, const sol::object& obj) { + // TODO: Make this trackable by making it a proper trigger and using execute trigger + // instead of calling the action here directly. + auto ap = make_action(obj); + (*ap)(sync, *executer_registrar_); +} + +void Coordinator::insert_trigger_from_lua(const Sync& sync, const sol::object& obj) { + store_trigger(make_trigger(obj), sync); +} + +size_t Coordinator::process_pending_lua_triggers(const Sync& sync) { + auto triggers = sol::object(luat_cloe_engine_initial_input(lua_)["triggers"]); + size_t count = 0; + for (auto& kv : triggers.as()) { + store_trigger(make_trigger(kv.second), sync); + count++; + } + luat_cloe_engine_initial_input(lua_)["triggers_processed"] = count; + return count; +} + size_t Coordinator::process_pending_web_triggers(const Sync& sync) { // The only thing we need to do here is distribute the triggers from the // input queue into their respective storage locations. We are responsible @@ -294,6 +329,30 @@ TriggerPtr Coordinator::make_trigger(Source s, const Conf& c) const { return t; } +ActionPtr Coordinator::make_action(const sol::object& lua) const { + if (lua.get_type() == sol::type::function) { + return std::make_unique("luafunction", lua); + } else { + return make_action(Conf{Json(lua)}); + } +} + +TriggerPtr Coordinator::make_trigger(const sol::table& lua) const { + sol::optional label = lua["label"]; + EventPtr ep = make_event(Conf{Json(lua["event"])}); + ActionPtr ap = make_action(sol::object(lua["action"])); + sol::optional action_source = lua["action_source"]; + if (!label && action_source) { + label = *action_source; + } else { + label = ""; + } + sol::optional sticky = lua["sticky"]; + auto tp = std::make_unique(*label, Source::LUA, std::move(ep), std::move(ap)); + tp->set_sticky(sticky.value_or(false)); + return tp; +} + void Coordinator::queue_trigger(TriggerPtr&& t) { if (t == nullptr) { // This only really happens if a trigger is an optional trigger. @@ -303,4 +362,18 @@ void Coordinator::queue_trigger(TriggerPtr&& t) { input_queue_.emplace_back(std::move(t)); } +void register_usertype_coordinator(sol::table& lua, const Sync& sync) { + // clang-format off + lua.new_usertype("Coordinator", + sol::no_constructor, + "insert_trigger", [&sync](Coordinator& self, const sol::object& obj) { + self.insert_trigger_from_lua(sync, obj); + }, + "execute_action", [&sync](Coordinator& self, const sol::object& obj) { + self.execute_action_from_lua(sync, obj); + } + ); + // clang-format on +} + } // namespace engine diff --git a/engine/src/coordinator.hpp b/engine/src/coordinator.hpp index a52452fd2..7bf21605b 100644 --- a/engine/src/coordinator.hpp +++ b/engine/src/coordinator.hpp @@ -30,7 +30,10 @@ #include // for string #include // for vector<> -#include // for Registrar +#include // for state_view +#include // for table + +#include // for Registrar, DataBroker #include // for Trigger, Action, Event, ... namespace engine { @@ -95,7 +98,7 @@ struct HistoryTrigger { */ class Coordinator { public: - Coordinator(); + Coordinator(sol::state_view lua); const std::vector& history() const { return history_; } @@ -104,6 +107,8 @@ class Coordinator { void register_event(const std::string& key, cloe::EventFactoryPtr&& ef, std::shared_ptr storage); + sol::table register_lua_table(const std::string& field); + std::shared_ptr trigger_registrar(cloe::Source s); void enroll(cloe::Registrar& r); @@ -116,11 +121,18 @@ class Coordinator { */ cloe::Duration process(const cloe::Sync&); + size_t process_pending_lua_triggers(const cloe::Sync& sync); size_t process_pending_web_triggers(const cloe::Sync& sync); + + void insert_trigger_from_lua(const cloe::Sync& sync, const sol::object& obj); + void execute_action_from_lua(const cloe::Sync& sync, const sol::object& obj); + protected: + cloe::ActionPtr make_action(const sol::object& lua) const; cloe::ActionPtr make_action(const cloe::Conf& c) const; cloe::EventPtr make_event(const cloe::Conf& c) const; cloe::TriggerPtr make_trigger(cloe::Source s, const cloe::Conf& c) const; + cloe::TriggerPtr make_trigger(const sol::table& tbl) const; void queue_trigger(cloe::Source s, const cloe::Conf& c) { queue_trigger(make_trigger(s, c)); } void queue_trigger(cloe::TriggerPtr&& tp); void store_trigger(cloe::TriggerPtr&& tp, const cloe::Sync& sync); @@ -136,6 +148,7 @@ class Coordinator { // Factories: std::map actions_; std::map events_; + sol::state_view lua_; // Execution: std::shared_ptr executer_registrar_; @@ -151,4 +164,6 @@ class Coordinator { std::vector history_; }; +void register_usertype_coordinator(sol::table& lua, const cloe::Sync& sync); + } // namespace engine diff --git a/engine/src/lua_action.cpp b/engine/src/lua_action.cpp new file mode 100644 index 000000000..5c8a2bfc2 --- /dev/null +++ b/engine/src/lua_action.cpp @@ -0,0 +1,89 @@ +/* + * Copyright 2022 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file lua_action.cpp + * \see lua_action.hpp + */ + +#include "lua_action.hpp" + +#include +#include + +#include +#include + +#include "lua_api.hpp" + +namespace engine { +namespace actions { + +cloe::CallbackResult LuaFunction::operator()(const cloe::Sync& sync, cloe::TriggerRegistrar&) { + logger()->trace("Running lua function."); + auto result = func_(std::ref(sync)); + if (!result.valid()) { + throw cloe::Error("error executing Lua function: {}", sol::error{result}.what()); + } + // Return false from a pinned action to remove it. + if (result.return_count() > 0 && !result.get()) { + return cloe::CallbackResult::Unpin; + } + return cloe::CallbackResult::Ok; +} + +cloe::CallbackResult Lua::operator()(const cloe::Sync&, cloe::TriggerRegistrar&) { + logger()->trace("Running lua script."); + auto result = lua_.script(script_); + if (!result.valid()) { + throw cloe::Error("error executing Lua function: {}", sol::error{result}.what()); + } + // Return false from a pinned action to remove it. + if (result.return_count() > 0 && !result.get()) { + return cloe::CallbackResult::Unpin; + } + return cloe::CallbackResult::Ok; +} + +void Lua::to_json(cloe::Json& j) const { + j = cloe::Json{ + {"script", script_}, + }; +} + +cloe::TriggerSchema LuaFactory::schema() const { + static const char* desc = "lua script to execute"; + return cloe::TriggerSchema{ + this->name(), this->description(), cloe::InlineSchema(desc, cloe::JsonType::string, true), + cloe::Schema{ + {"script", cloe::make_prototype("lua script to execute")}, + }}; +} + +cloe::ActionPtr LuaFactory::make(const cloe::Conf& c) const { + auto script = c.get("script"); + return std::make_unique(name(), script, lua_); +} + +cloe::ActionPtr LuaFactory::make(const std::string& s) const { + return make(cloe::Conf{cloe::Json{ + {"script", s}, + }}); +} + +} // namespace actions +} // namespace engine diff --git a/engine/src/lua_action.hpp b/engine/src/lua_action.hpp new file mode 100644 index 000000000..a7b99730a --- /dev/null +++ b/engine/src/lua_action.hpp @@ -0,0 +1,84 @@ +/* + * Copyright 2022 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file lua_action.hpp + * \see lua_action.cpp + * + * This file contains several types that make use of cloe::Lua. + */ + +#pragma once + +#include +#include + +#include // for Logger, Json, Conf, ... +#include // for Action, ActionFactory, ... + +namespace engine { +namespace actions { + +class LuaFunction : public cloe::Action { + public: + LuaFunction(const std::string& name, sol::function fun) : Action(name), func_(fun) {} + + cloe::ActionPtr clone() const override { + return std::make_unique(this->name(), func_); + } + + cloe::CallbackResult operator()(const cloe::Sync& sync, cloe::TriggerRegistrar&) override; + + void to_json(cloe::Json& j) const override { j = cloe::Json{}; } + + private: + sol::protected_function func_; +}; + +class Lua : public cloe::Action { + public: + Lua(const std::string& name, const std::string& script, sol::state_view lua) + : Action(name), script_(script), lua_(lua) {} + + cloe::ActionPtr clone() const override { return std::make_unique(name(), script_, lua_); } + + cloe::CallbackResult operator()(const cloe::Sync&, cloe::TriggerRegistrar&) override; + + protected: + void to_json(cloe::Json& j) const override; + + private: + std::string script_; + sol::state_view lua_; +}; + +class LuaFactory : public cloe::ActionFactory { + public: + using ActionType = Lua; + explicit LuaFactory(sol::state_view lua) + : cloe::ActionFactory("lua", "run a lua script"), lua_(lua) { + } + cloe::TriggerSchema schema() const override; + cloe::ActionPtr make(const cloe::Conf& c) const override; + cloe::ActionPtr make(const std::string& s) const override; + + private: + sol::state_view lua_; +}; + +} // namespace actions +} // namespace engine diff --git a/engine/src/lua_api.cpp b/engine/src/lua_api.cpp new file mode 100644 index 000000000..6043341ab --- /dev/null +++ b/engine/src/lua_api.cpp @@ -0,0 +1,49 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "lua_api.hpp" + +#include // for path + +#include // for logger::get +#include // for state_view + +namespace cloe { + +sol::protected_function_result lua_safe_script_file(sol::state_view& lua, + const std::filesystem::path& filepath) { + auto file = std::filesystem::path(filepath); + auto dir = file.parent_path().generic_string(); + if (dir.empty()) { + dir = "."; + } + + auto state = luat_cloe_engine_state(lua); + auto old_file = state["current_script_file"]; + auto old_dir = state["current_script_dir"]; + state["scripts_loaded"].get().add(file.generic_string()); + state["current_script_file"] = file.generic_string(); + state["current_script_dir"] = dir; + logger::get("cloe")->info("Loading {}", file.generic_string()); + auto result = lua.safe_script_file(file.generic_string(), sol::script_pass_on_error); + state["current_script_file"] = old_file; + state["current_script_dir"] = old_dir; + return result; +} + +} // namespace cloe diff --git a/engine/src/lua_api.hpp b/engine/src/lua_api.hpp new file mode 100644 index 000000000..1af36e676 --- /dev/null +++ b/engine/src/lua_api.hpp @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * This file contains functions for dealing with Lua *after* it has been + * set up. + * + * \file lua_api.hpp + * \see lua_api.cpp + */ + +#pragma once + +#include // for std::filesystem::path + +#include // for protected_function_result +#include // for state_view + +namespace cloe { + +/** + * Safely load and run a user Lua script. + */ +[[nodiscard]] sol::protected_function_result lua_safe_script_file( + sol::state_view& lua, const std::filesystem::path& filepath); + +/** + * Return the cloe-engine table as it is exported into Lua. + * + * If you make any changes to these paths, make sure to reflect it: + * + * engine/lua/cloe-engine/init.lua + * + */ +[[nodiscard]] inline auto luat_cloe_engine(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine"]; +} + +[[nodiscard]] inline auto luat_cloe_engine_fs(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine.fs"]; +} + +[[nodiscard]] inline auto luat_cloe_engine_types(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine.types"]; +} + +[[nodiscard]] inline auto luat_cloe_engine_initial_input(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine"]["initial_input"]; +} + +[[nodiscard]] inline auto luat_cloe_engine_state(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine"]["state"]; +} + +[[nodiscard]] inline auto luat_cloe_engine_plugins(sol::state_view& lua) { + return lua["package"]["loaded"]["cloe-engine"]["plugins"]; +} + +} // namespace cloe diff --git a/engine/src/lua_setup.cpp b/engine/src/lua_setup.cpp new file mode 100644 index 000000000..b6b66d48e --- /dev/null +++ b/engine/src/lua_setup.cpp @@ -0,0 +1,342 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file stack_lua.cpp + */ + +#include "lua_setup.hpp" + +#include // for path + +#include // for state_view + +#include // for split_string + +#include // for Json(sol::object) +#include // for join_vector + +#include "error_handler.hpp" // for format_cloe_error +#include "lua_api.hpp" +#include "stack.hpp" +#include "utility/command.hpp" // for CommandExecuter, CommandResult + +// This variable is set from CMakeLists.txt, but in case it isn't, +// we will assume that the server is disabled. +#ifndef CLOE_ENGINE_WITH_SERVER +#define CLOE_ENGINE_WITH_SERVER 0 +#endif + +#ifndef CLOE_LUA_PATH +#define CLOE_LUA_PATH "CLOE_LUA_PATH" +#endif + +namespace cloe { + +namespace { + +void cloe_api_log(const std::string& level, const std::string& prefix, const std::string& msg) { + auto lev = logger::into_level(level); + auto log = cloe::logger::get(prefix.empty() ? prefix : "lua"); + log->log(lev, msg.c_str()); +} + +std::tuple cloe_api_exec(sol::object obj, sol::this_state s) { + // FIXME: This is not a very nice function... + Command cmd; + cmd.from_conf(fable::Conf{Json(obj)}); + + engine::CommandExecuter exec(cloe::logger::get("lua")); + auto result = exec.run_and_release(cmd); + if (cmd.mode() != cloe::Command::Mode::Sync) { + return {sol::lua_nil, sol::lua_nil}; + } + sol::state_view lua(s); + return { + sol::object(lua, sol::in_place, fable::join_vector(result.output, "\n")), + sol::object(lua, sol::in_place, *result.exit_code), + }; +} + +template +inline bool contains(const std::vector& v, const T& x) { + return std::find(v.begin(), v.end(), x) != v.end(); +} + +// Handle the exception. +// +// @param L the lua state, which you can wrap in a state_view if necessary +// @param error the exception, if it exists +// @param desc the what() of the exception or a description saying that we hit the general-case catch(...) +// @return Return value of sol::stack::push() +int lua_exception_handler(lua_State* L, sol::optional maybe_exception, + sol::string_view desc) { + if (maybe_exception) { + const std::exception& err = *maybe_exception; + std::cerr << "Error: " << format_error(err) << std::endl; + } else { + std::cerr << "Error: "; + std::cerr.write(desc.data(), static_cast(desc.size())); + std::cerr << std::endl; + } + + // you must push 1 element onto the stack to be + // transported through as the error object in Lua + // note that Lua -- and 99.5% of all Lua users and libraries + // -- expects a string so we push a single string (in our + // case, the description of the error) + return sol::stack::push(L, desc); +} + +/** + * Add package path to Lua search path. + * + * \see lua_setup_builtin.cpp + */ +void configure_package_path(sol::state_view& lua, const std::vector& paths) { + std::string package_path = lua["package"]["path"]; + for (const std::string& p : paths) { + package_path += ";" + p + "/?.lua"; + package_path += ";" + p + "/?/init.lua"; + } + lua["package"]["path"] = package_path; +} + +/** + * Add Lua package paths so that bundled Lua libaries can be found. + */ +void register_package_path(sol::state_view& lua, const LuaOptions& opt) { + // Setup lua path: + std::vector lua_path{}; + if (!opt.no_system_lua) { + // FIXME(windows): These paths are linux-specific. + lua_path = { + "/usr/local/lib/cloe/lua", + "/usr/lib/cloe/lua", + }; + } + std::string lua_paths = opt.environment->get_or(CLOE_LUA_PATH, ""); + for (auto&& p : utility::split_string(std::move(lua_paths), ":")) { + if (contains(lua_path, p)) { + continue; + } + lua_path.emplace_back(std::move(p)); + } + for (const auto& p : opt.lua_paths) { + if (contains(lua_path, p)) { + continue; + } + lua_path.emplace_back(p); + } + + configure_package_path(lua, lua_path); +} + +/** + * Load "cloe-engine" library into Lua. + * + * This is then available via: + * + * require("cloe-engine") + * + * Any changes you make here should be documented in the Lua meta files. + * + * engine/lua/cloe-engine/init.lua + */ +void register_cloe_engine(sol::state_view& lua, Stack& stack) { + sol::table tbl = lua.create_table(); + + // Initial input will be processed at simulation start. + tbl["initial_input"] = lua.create_table(); + tbl["initial_input"]["triggers"] = lua.create_table(); + tbl["initial_input"]["triggers_processed"] = 0; + tbl["initial_input"]["signal_aliases"] = lua.create_table(); + tbl["initial_input"]["signal_requires"] = lua.create_table(); + + // Plugin access will be made available by Coordinator. + tbl["plugins"] = lua.create_table(); + + // Simulation state will be extended in simulation. + // clang-format off + tbl["state"] = lua.create_table(); + tbl["state"]["report"] = lua.create_table(); + tbl["state"]["stack"] = std::ref(stack); + tbl["state"]["config"] = fable::into_sol_object(lua, stack.active_config()); + tbl["state"]["scheduler"] = sol::lua_nil; + tbl["state"]["current_script_file"] = sol::lua_nil; + tbl["state"]["current_script_dir"] = sol::lua_nil; + tbl["state"]["scripts_loaded"] = lua.create_table(); + tbl["state"]["features"] = lua.create_table_with( + // Version compatibility: + "cloe-0.18.0", true, + "cloe-0.18", true, + "cloe-0.19.0", true, + "cloe-0.19", true, + "cloe-0.20.0", true, + "cloe-0.20", true, + "cloe-0.21.0", true, // nightly + "cloe-0.21", true, // nightly + + // Stackfile versions support: + "cloe-stackfile", true, + "cloe-stackfile-4", true, + "cloe-stackfile-4.0", true, + "cloe-stackfile-4.1", true, + + // Server enabled: + "cloe-server", CLOE_ENGINE_WITH_SERVER != 0 + ); + // clang-format on + +#if 0 + tbl.set_function("is_available", []() { return true; }); + tbl.set_function("get_script_file", [](sol::this_state lua) { + return luat_cloe_engine_state(lua)["current_script_file"]; + }); + tbl.set_function("get_script_dir", [](sol::this_state lua) { + return luat_cloe_engine_state(lua)["current_script_dir"]; + }); + tbl.set_function("get_report", + [](sol::this_state lua) { return luat_cloe_engine_state(lua)["report"]; }); + tbl.set_function("get_scheduler", + [](sol::this_state lua) { return luat_cloe_engine_state(lua)["scheduler"]; }); + tbl.set_function("get_features", + [](sol::this_state lua) { return luat_cloe_engine_state(lua)["features"]; }); + tbl.set_function("get_stack", + [](sol::this_state lua) { return luat_cloe_engine_state(lua)["stack"]; }); +#endif + tbl.set_function("log", cloe_api_log); + tbl.set_function("exec", cloe_api_exec); + + luat_cloe_engine(lua) = tbl; +} + +void register_enum_loglevel(sol::state_view& lua, sol::table& tbl) { + // clang-format off + tbl["LogLevel"] = lua.create_table_with( + "TRACE", "trace", + "DEBUG", "debug", + "INFO", "info", + "WARN", "warn", + "ERROR", "error", + "CRITICAL", "critical" + ); + // clang-format on +} + +/** + * Load "cloe-engine.types" library into Lua. + * + * This is then available via: + * + * require("cloe-engine.types") + * + * Any changes you make here should be documented in the Lua meta files. + * + * engine/lua/cloe-engine/types.lua + */ +void register_cloe_engine_types(sol::state_view& lua) { + sol::table tbl = lua.create_table(); + register_usertype_duration(tbl); + register_usertype_sync(tbl); + register_usertype_stack(tbl); + register_enum_loglevel(lua, tbl); + luat_cloe_engine_types(lua) = tbl; +} + +/** + * Load "cloe-engine.fs" library into Lua. + * + * This is then available via: + * + * require("cloe-engine.fs") + * + * Any changes you make here should be documented in the Lua meta files: + * + * engine/lua/cloe-engine/fs.lua + */ +void register_cloe_engine_fs(sol::state_view& lua) { + sol::table tbl = lua.create_table(); + register_lib_fs(tbl); + luat_cloe_engine_fs(lua) = tbl; +} + +/** + * Add cloe lazy-loader into Lua global namespace. + * + * You can just use `cloe`, and it will auto-require the cloe module. + * If you don't use it, then it won't be loaded. + */ +void register_cloe(sol::state_view& lua) { + // This takes advantage of the `__index` function for metatables, which is called + // when a key can't be found in the original table, here an empty table + // assigned to cloe. It then loads the cloe module, and returns the key + // requested. The next access will no longer trigger this method, because we + // swapped tables. + // + // Effectively, this lazy-loads the cloe library. This allows us to not + // load it and all the other modules it pulls in, which allows us to for + // example, configure those libraries before cloe does. + auto result = lua.safe_script(R"==( + cloe = setmetatable({}, { + __index = function(_, k) + _G["cloe"] = require("cloe") + return _G["cloe"][k] + end + }) + )=="); + assert(result.valid()); +} + +} // anonymous namespace + +sol::state new_lua(const LuaOptions& opt, Stack& stack) { + // clang-format off + sol::state lua; + lua.open_libraries( + sol::lib::base, + sol::lib::coroutine, + sol::lib::debug, + sol::lib::io, + sol::lib::math, + sol::lib::os, + sol::lib::package, + sol::lib::string, + sol::lib::table + ); + lua.set_exception_handler(&lua_exception_handler); + // clang-format on + + register_package_path(lua, opt); + register_cloe_engine(lua, stack); + register_cloe_engine_types(lua); + register_cloe_engine_fs(lua); + if (opt.auto_require_cloe) { + register_cloe(lua); + } + return lua; +} + +void merge_lua(sol::state_view& lua, const std::string& filepath) { + logger::get("cloe")->debug("Load script {}", filepath); + auto result = lua_safe_script_file(lua, std::filesystem::path(filepath)); + if (!result.valid()) { + throw sol::error(result); + } +} + +} // namespace cloe diff --git a/engine/src/lua_setup.hpp b/engine/src/lua_setup.hpp new file mode 100644 index 000000000..0aa84914d --- /dev/null +++ b/engine/src/lua_setup.hpp @@ -0,0 +1,112 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * This file contains function definitions required to set up the Lua API. + * + * \file lua_setup.hpp + */ + +#pragma once + +#include // for ostream, cerr +#include // for shared_ptr<> +#include // for optional<> +#include // for string +#include // for vector<> + +#include +#include +#include + +#include // for Environment + +namespace cloe { + +class Stack; + +struct LuaOptions { + std::shared_ptr environment; + + std::vector lua_paths; + bool no_system_lua = false; + bool auto_require_cloe = false; +}; + +/** + * Create a new lua state. + * + * Currently this requires a fully configured Stack file. + * + * \see cloe::new_stack() + * \see stack_factory.hpp + * \see lua_setup.cpp + */ +sol::state new_lua(const LuaOptions& opt, Stack& s); + +/** + * Merge the provided Lua file into the existing `Stack`, respecting `StackOptions`. + * + * \see lua_setup.cpp + */ +void merge_lua(sol::state_view& lua, const std::string& filepath); + +/** + * Define the filesystem library functions in the given table. + * + * The following functions are made available: + * + * - basename + * - dirname + * - normalize + * - realpath + * - join + * - is_absolute + * - is_relative + * - is_dir + * - is_file + * - is_other + * - exists + * + * \see lua_setup_fs.cpp + */ +void register_lib_fs(sol::table& lua); + +/** + * Define `cloe::Duration` usertype in Lua. + * + * \see cloe/core/duration.hpp from cloe-runtime + * \see lua_setup_duration.cpp + */ +void register_usertype_duration(sol::table& lua); + +/** + * Define `cloe::Sync` usertype in Lua. + * + * \see cloe/sync.hpp from cloe-runtime + * \see lua_setup_sync.cpp + */ +void register_usertype_sync(sol::table& lua); + +/** + * Define `cloe::Stack` usertype in Lua. + * + * \see clua_setup_stack.cpp + */ +void register_usertype_stack(sol::table& lua); + +} // namespace cloe diff --git a/engine/src/lua_setup_duration.cpp b/engine/src/lua_setup_duration.cpp new file mode 100644 index 000000000..9eed7ed01 --- /dev/null +++ b/engine/src/lua_setup_duration.cpp @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "lua_setup.hpp" + +#include // for Duration +#include // for parse_duration + +namespace cloe { + +void register_usertype_duration(sol::table& lua) { + Duration (*parse_duration_ptr)(const std::string&) = ::fable::parse_duration; + std::string (*to_string_ptr)(const Duration&) = ::fable::to_string; + lua.new_usertype<::cloe::Duration>("Duration", + sol::factories(parse_duration_ptr), + sol::meta_function::to_string, to_string_ptr, + sol::meta_function::addition, + sol::resolve(std::chrono::operator+), + sol::meta_function::subtraction, + sol::resolve(std::chrono::operator-), + sol::meta_function::division, + [](const Duration& x, double d) -> Duration { Duration y(x); y /= d; return y; }, + sol::meta_function::multiplication, + [](const Duration& x, double d) -> Duration { Duration y(x); y *= d; return y; }, + "ns", &Duration::count, + "us", [](const Duration& d) -> double { return static_cast(d.count()) / 10e2; }, + "ms", [](const Duration& d) -> double { return static_cast(d.count()) / 10e5; }, + "s", [](const Duration& d) -> double { return static_cast(d.count()) / 10e8; } + ); +} + +} // namespace cloe diff --git a/engine/src/lua_setup_fs.cpp b/engine/src/lua_setup_fs.cpp new file mode 100644 index 000000000..e45a594bd --- /dev/null +++ b/engine/src/lua_setup_fs.cpp @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "lua_setup.hpp" + +#include // for path +namespace fs = std::filesystem; + +#include + +namespace cloe { + +namespace { + +std::string basename(const std::string& file) { return fs::path(file).filename().generic_string(); } + +std::string dirname(const std::string& file) { + return fs::path(file).parent_path().generic_string(); +} + +std::string normalize(const std::string& file) { + return fs::path(file).lexically_normal(); +} + +std::string realpath(const std::string& file) { + std::error_code ec; + auto p = fs::canonical(fs::path(file), ec); + if (ec) { + // FIXME: Implement proper error handling for Lua API. + return ""; + } + return p.generic_string(); +} + +std::string join(const std::string& file_left, const std::string& file_right) { + return (fs::path(file_left) / fs::path(file_right)).generic_string(); +} + +bool is_absolute(const std::string& file) { return fs::path(file).is_absolute(); } + +bool is_relative(const std::string& file) { return fs::path(file).is_relative(); } + +bool is_dir(const std::string& file) { return fs::is_directory(fs::path(file)); } + +bool is_file(const std::string& file) { return fs::is_regular_file(fs::path(file)); } + +bool is_symlink(const std::string& file) { return fs::is_symlink(fs::path(file)); } + +// It is NOT a directory, regular file, or symlink. +// Therefore, it is either a +// - block file, +// - character file, +// - fifo pipe, or +// - socket. +bool is_other(const std::string& file) { return fs::is_other(fs::path(file)); } + +bool exists(const std::string& file) { return fs::exists(fs::path(file)); } + +} // anonymous namespace + +void register_lib_fs(sol::table& lua) { + lua.set_function("basename", basename); + lua.set_function("dirname", dirname); + lua.set_function("normalize", normalize); + lua.set_function("realpath", realpath); + lua.set_function("join", join); + + lua.set_function("is_absolute", is_absolute); + lua.set_function("is_relative", is_relative); + lua.set_function("is_dir", is_dir); + lua.set_function("is_file", is_file); + lua.set_function("is_other", is_other); + + lua.set_function("exists", exists); +} + +} // namespace cloe diff --git a/engine/src/lua_setup_stack.cpp b/engine/src/lua_setup_stack.cpp new file mode 100644 index 000000000..da563a32f --- /dev/null +++ b/engine/src/lua_setup_stack.cpp @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include "lua_api.hpp" +#include "lua_setup.hpp" +#include "stack.hpp" + +namespace cloe { + +void register_usertype_stack(sol::table& lua) { + auto stack = lua.new_usertype("Stack", sol::no_constructor); + stack["active_config"] = [](Stack& self, sol::this_state lua) { + return fable::into_sol_object(lua, self.active_config()); + }; + stack["input_config"] = [](Stack& self, sol::this_state lua) { + return fable::into_sol_object(lua, self.input_config()); + }; + stack["merge_stackfile"] = &Stack::merge_stackfile; + stack["merge_stackjson"] = [](Stack& self, const std::string& json, std::string file) { + self.from_conf(Conf{fable::parse_json(json), std::move(file)}); + }; + stack["merge_stacktable"] = [](Stack& self, sol::object obj, std::string file) { + self.from_conf(Conf{Json(obj), std::move(file)}); + }; +} + +} // namespace cloe diff --git a/engine/src/lua_setup_sync.cpp b/engine/src/lua_setup_sync.cpp new file mode 100644 index 000000000..187fdb2a4 --- /dev/null +++ b/engine/src/lua_setup_sync.cpp @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "lua_setup.hpp" + +#include + +namespace cloe { + +void register_usertype_sync(sol::table& lua) { + lua.new_usertype<::cloe::Sync> + ("Sync", + sol::no_constructor, + "step", &Sync::step, + "step_width", &Sync::step_width, + "time", &Sync::time, + "eta", &Sync::eta, + "realtime_factor", &Sync::realtime_factor, + "is_realtime_factor_unlimited", &Sync::is_realtime_factor_unlimited, + "achievable_realtime_factor", &Sync::achievable_realtime_factor + ); +} + +} // namespace cloe diff --git a/engine/src/lua_stack_test.cpp b/engine/src/lua_stack_test.cpp new file mode 100644 index 000000000..9022fcf8a --- /dev/null +++ b/engine/src/lua_stack_test.cpp @@ -0,0 +1,91 @@ + +#include +#include +#include + +#include + +#include // for Json +#include +#include // for assert_from_conf +#include + +#include "lua_setup.hpp" +#include "stack.hpp" // for Stack +using namespace cloe; // NOLINT(build/namespaces) + +TEST(cloe_lua_stack, deserialize_vehicle_conf) { + sol::state lua; + + lua.open_libraries(sol::lib::base); + + lua.script("from = { index = 0, simulator = \"nop\" }"); + lua.script("print(from.index)"); + + cloe::FromSimulator fromsim; + sol::object obj = lua["from"]; + cloe::Json json(obj); + try { + fromsim.from_conf(Conf{json}); + } catch (fable::SchemaError& err) { + fable::pretty_print(err, std::cerr); + FAIL(); + } +} + +TEST(cloe_lua_stack, convert_json_lua) { + sol::state lua; + lua.open_libraries(sol::lib::base); + + Stack s; + + lua["stack"] = fable::into_sol_object(lua, s.active_config()); + lua.script(R"( + assert(stack) + assert(stack.version == "4.1") + assert(stack.engine) + )"); +} + +TEST(cloe_lua_stack, copy_stack_json_lua) { + sol::state lua; + lua.open_libraries(sol::lib::base); + + Stack s1; + s1.engine.keep_alive = true; // change something + Stack s2 = s1; // copy + + lua["s1"] = fable::into_sol_object(lua, s1.active_config()); + lua["s2"] = fable::into_sol_object(lua, s2.active_config()); + lua.script(R"( + assert(s1) + assert(s1.version == "4.1") + assert(s1.engine) + )"); + + lua.script(R"( + function deep_equal(a, b) + if a == b then + return true + end + if type(a) ~= type(b) then + return false + end + if type(a) == 'table' then + for k, v in pairs(a) do + if not deep_equal(v, b[k]) then + return false + end + end + for k, _ in pairs(b) do + if a[k] == nil then + return false + end + end + return true + end + return false + end + assert(deep_equal(s1, s2)) + )"); +} diff --git a/engine/src/main.cpp b/engine/src/main.cpp index 695fdd997..5dda6490b 100644 --- a/engine/src/main.cpp +++ b/engine/src/main.cpp @@ -84,6 +84,15 @@ int main(int argc, char** argv) { // One of the above subcommands must be used. app.require_subcommand(); + // Shell Command: + engine::ShellOptions shell_options{}; + std::vector shell_files{}; + auto* shell = app.add_subcommand("shell", "Start a Lua shell."); + shell->add_flag("-i,--interactive,!--no-interactive", shell_options.interactive, + "Drop into interactive mode (default)"); + shell->add_option("-c,--command", shell_options.commands, "Lua to run after running files"); + shell->add_option("files", shell_files, "Lua files to run before starting the shell"); + // Global Options: std::string log_level = "warn"; app.set_help_all_flag("-H,--help-all", "Print all help messages and exit"); @@ -110,11 +119,18 @@ int main(int argc, char** argv) { app.add_flag("--interpolate-undefined", stack_options.interpolate_undefined, "Interpolate undefined variables with empty strings"); + cloe::LuaOptions lua_options{}; + lua_options.environment = stack_options.environment; + app.add_option("--lua-path", lua_options.lua_paths, + "Scan directory for lua files when loading modules (Env:CLOE_LUA_PATH)"); + app.add_flag("--no-system-lua", lua_options.no_system_lua, "Disable default Lua system paths"); + // The --strict flag here is useful for all our smoketests, since this is the // combination of flags we use for maximum reproducibility / isolation. // Note: This option also affects / overwrites options for the run subcommand! - app.add_flag("-t,--strict,!--no-strict", stack_options.strict_mode, - "Forces flags: --no-system-plugins --no-system-confs --require-success") + app.add_flag( + "-t,--strict,!--no-strict", stack_options.strict_mode, + "Forces flags: --no-system-plugins --no-system-confs --no-system-lua --require-success") ->envname("CLOE_STRICT_MODE"); app.add_flag("-s,--secure,!--no-secure", stack_options.secure_mode, "Forces flags: --strict --no-hooks --no-interpolate") @@ -147,6 +163,7 @@ int main(int argc, char** argv) { if (stack_options.strict_mode) { stack_options.no_system_plugins = true; stack_options.no_system_confs = true; + lua_options.no_system_lua = true; run_options.require_success = true; } stack_options.environment->prefer_external(false); @@ -156,6 +173,7 @@ int main(int argc, char** argv) { auto with_global_options = [&](auto& opt) -> decltype(opt) { std::swap(opt.stack_options, stack_options); + std::swap(opt.lua_options, lua_options); return opt; }; @@ -172,6 +190,8 @@ int main(int argc, char** argv) { return engine::check(with_global_options(check_options), check_files); } else if (*run) { return engine::run(with_global_options(run_options), run_files); + } else if (*shell) { + return engine::shell(with_global_options(shell_options), shell_files); } } catch (cloe::ConcludedError& e) { return EXIT_FAILURE; diff --git a/engine/src/main_commands.hpp b/engine/src/main_commands.hpp index 94aac1d65..fcdb60a9c 100644 --- a/engine/src/main_commands.hpp +++ b/engine/src/main_commands.hpp @@ -25,12 +25,14 @@ #include // for optional<> +#include "lua_setup.hpp" #include "stack_factory.hpp" namespace engine { struct CheckOptions { cloe::StackOptions stack_options; + cloe::LuaOptions lua_options; std::ostream* output = &std::cout; std::ostream* error = &std::cerr; @@ -46,6 +48,7 @@ int check(const CheckOptions& opt, const std::vector& filepaths); struct DumpOptions { cloe::StackOptions stack_options; + cloe::LuaOptions lua_options; std::ostream* output = &std::cout; std::ostream* error = &std::cerr; @@ -58,6 +61,7 @@ int dump(const DumpOptions& opt, const std::vector& filepaths); struct RunOptions { cloe::StackOptions stack_options; + cloe::LuaOptions lua_options; std::ostream* output = &std::cout; std::ostream* error = &std::cerr; @@ -77,6 +81,7 @@ int run(const RunOptions& opt, const std::vector& filepaths); struct ShellOptions { cloe::StackOptions stack_options; + cloe::LuaOptions lua_options; std::ostream* output = &std::cout; std::ostream* error = &std::cerr; @@ -93,6 +98,7 @@ int shell(const ShellOptions& opt, const std::vector& filepaths); struct UsageOptions { cloe::StackOptions stack_options; + cloe::LuaOptions lua_options; std::ostream* output = &std::cout; std::ostream* error = &std::cerr; diff --git a/engine/src/main_run.cpp b/engine/src/main_run.cpp index 06a81e5de..6773624e7 100644 --- a/engine/src/main_run.cpp +++ b/engine/src/main_run.cpp @@ -34,7 +34,7 @@ #include // for read_conf #include "error_handler.hpp" // for conclude_error -#include "main_commands.hpp" // for RunOptions, new_stack +#include "main_commands.hpp" // for RunOptions, new_stack, new_lua #include "simulation.hpp" // for Simulation, SimulationResult #include "stack.hpp" // for Stack @@ -63,10 +63,15 @@ int run(const RunOptions& opt, const std::vector& filepaths) { // Load the stack file: cloe::Stack stack = cloe::new_stack(opt.stack_options); + sol::state lua = cloe::new_lua(opt.lua_options, stack); try { cloe::conclude_error(*opt.stack_options.error, [&]() { for (const auto& file : filepaths) { - cloe::merge_stack(opt.stack_options, stack, file); + if (boost::algorithm::ends_with(file, ".lua")) { + cloe::merge_lua(lua, file); + } else { + cloe::merge_stack(opt.stack_options, stack, file); + } } if (!opt.allow_empty) { stack.check_completeness(); @@ -77,7 +82,7 @@ int run(const RunOptions& opt, const std::vector& filepaths) { } // Create simulation: - Simulation sim(std::move(stack), uuid); + Simulation sim(std::move(stack), std::move(lua), uuid); GLOBAL_SIMULATION_INSTANCE = ∼ std::ignore = std::signal(SIGINT, handle_signal); diff --git a/engine/src/main_shell.cpp b/engine/src/main_shell.cpp new file mode 100644 index 000000000..732f2cbe4 --- /dev/null +++ b/engine/src/main_shell.cpp @@ -0,0 +1,191 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include // for cout, cerr, endl +#include // for optional<> +#include // for string +#include // for make_pair<> +#include // for vector<> + +#include + +#include +#include // for ends_with + +#include "lua_api.hpp" // for lua_safe_script_file +#include "main_commands.hpp" // for Stack, new_stack, LuaOptions, new_lua +#include "stack.hpp" // for Stack + +namespace engine { + +template +void print_error(std::ostream& os, const S& chunk) { + auto err = sol::error(chunk); + os << sol::to_string(chunk.status()) << " error: " << err.what() << std::endl; +} + +bool evaluate(sol::state& lua, std::ostream& os, const char* buf) { + try { + auto result = lua.safe_script(buf, sol::script_pass_on_error); + if (!result.valid()) { + print_error(os, result); + return false; + } + } catch (const std::exception& e) { + os << "runtime error: " << e.what() << std::endl; + return false; + } + return true; +} + +int noninteractive_shell(sol::state& lua, std::ostream& os, const std::vector& actions, + bool ignore_errors) { + int errors = 0; + for (const auto& action : actions) { + auto ok = evaluate(lua, os, action.c_str()); + if (!ok) { + errors++; + if (!ignore_errors) { + break; + } + } + } + return errors; +} + +void interactive_shell(sol::state& lua, std::ostream& os, const std::vector& actions, + bool ignore_errors) { + constexpr auto PROMPT = "> "; + constexpr auto PROMPT_CONTINUE = ">> "; + constexpr auto PROMPT_HISTORY = "< "; + constexpr auto HISTORY_LENGTH = 1024; + + // Set up linenoise library + linenoiseSetMultiLine(1); + linenoiseHistorySetMaxLen(HISTORY_LENGTH); + + os << "Cloe " << CLOE_ENGINE_VERSION << " Lua interactive shell" << std::endl; + os << "Press [Ctrl+D] or [Ctrl+C] to exit." << std::endl; + + // Run actions from command line first + auto remaining = actions.size(); + for (const auto& action : actions) { + os << PROMPT_HISTORY << action << std::endl; + linenoiseHistoryAdd(action.c_str()); + remaining--; + auto ok = evaluate(lua, os, action.c_str()); + if (!ok && !ignore_errors) { + break; + } + } + if (remaining != 0) { + os << "warning: dropping to interactive console early due to error" << std::endl; + } + + // Export describe into global namespace + lua.safe_script(R"#( function describe(obj) print(require("inspect").inspect(obj)) end )#"); + + // Start REPL loop + std::string buf; + std::string vbuf; + for (;;) { + auto* line = linenoise(buf.empty() ? PROMPT : PROMPT_CONTINUE); + if (line == nullptr) { + break; + } + buf += line; + linenoiseFree(line); + + sol::load_result chunk; + { + // Enable return value printing by injecting "return"; + // if it does not parse, then we abort and use original value. + vbuf = "return " + buf; + chunk = lua.load(vbuf); + if (!chunk.valid()) { + chunk = lua.load(buf); + } + } + + if (!chunk.valid()) { + auto err = sol::error(chunk); + if (fable::ends_with(err.what(), "near ")) { + // The following error messages seem to indicate + // that Lua is just waiting for more to complete the statement: + // + // 'end' expected near + // unexpected symbol near + // or '...' expected near + // + // In this case, we don't clear buf, but instead allow + // the user to continue inputting on the next line. + buf += " "; + continue; + } + print_error(os, chunk); + buf.clear(); + continue; + } + + auto script = chunk.get(); + auto result = script(); + if (!result.valid()) { + print_error(os, result); + } else if (result.return_count() > 0) { + for (auto r : result) { + lua["describe"](r); + } + } + + // Clear buf for next input line + linenoiseHistoryAdd(buf.c_str()); + buf.clear(); + } +} + +int shell(const ShellOptions& opt, const std::vector& filepaths) { + assert(opt.output != nullptr && opt.error != nullptr); + + cloe::StackOptions stack_opt = opt.stack_options; + cloe::Stack stack = cloe::new_stack(stack_opt); + auto lopt = opt.lua_options; + lopt.auto_require_cloe = true; + sol::state lua = cloe::new_lua(lopt, stack); + + // Collect input files and strings to execute + std::vector actions{}; + actions.reserve(filepaths.size() + opt.commands.size()); + for (const auto& file : filepaths) { + actions.emplace_back(fmt::format("dofile(\"{}\")", file)); + } + actions.insert(actions.end(), opt.commands.begin(), opt.commands.end()); + + // Determine whether we should be interactive or not + bool interactive = opt.interactive ? *opt.interactive : opt.commands.empty() && filepaths.empty(); + if (!interactive) { + auto errors = noninteractive_shell(lua, *opt.error, actions, opt.ignore_errors); + if (errors != 0) { + return EXIT_FAILURE; + } + } else { + interactive_shell(lua, *opt.output, actions, opt.ignore_errors); + } + return EXIT_SUCCESS; +} + +} // namespace engine diff --git a/engine/src/registrar.hpp b/engine/src/registrar.hpp index 948fc6cff..bba820b1d 100644 --- a/engine/src/registrar.hpp +++ b/engine/src/registrar.hpp @@ -33,9 +33,8 @@ namespace engine { class Registrar : public cloe::Registrar { public: - Registrar(std::unique_ptr r, std::shared_ptr c) - : server_registrar_(std::move(r)) - , coordinator_(std::move(c)) {} + Registrar(std::unique_ptr r, Coordinator* c) + : server_registrar_(std::move(r)), coordinator_(c) {} Registrar(const Registrar& ar, const std::string& trigger_prefix, @@ -104,9 +103,13 @@ class Registrar : public cloe::Registrar { coordinator_->register_event(trigger_key(ef->name()), std::move(ef), storage); } + sol::table register_lua_table() override { + return coordinator_->register_lua_table(trigger_prefix_); + } + private: std::unique_ptr server_registrar_; - std::shared_ptr coordinator_; + Coordinator* coordinator_; // non-owning std::string trigger_prefix_; }; diff --git a/engine/src/simulation.cpp b/engine/src/simulation.cpp index 3031815fd..81ac83082 100644 --- a/engine/src/simulation.cpp +++ b/engine/src/simulation.cpp @@ -93,7 +93,11 @@ #include // for INCLUDE_RESOURCE, RESOURCE_HANDLER #include // for Vehicle #include // for pretty_print +#include // for sol::object to_json +#include "coordinator.hpp" // for register_usertype_coordinator +#include "lua_action.hpp" // for LuaAction, +#include "lua_api.hpp" // for to_json(json, sol::object) #include "simulation_context.hpp" // for SimulationContext #include "utility/command.hpp" // for CommandFactory #include "utility/state_machine.hpp" // for State, StateMachine @@ -324,8 +328,7 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { ctx.server->refresh_buffer(); }; - { - // 2. Initialize loggers + { // 2. Initialize loggers update_progress("logging"); for (const auto& c : ctx.config.logging) { @@ -333,8 +336,14 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { } } - { - // 3. Enroll endpoints and triggers for the server + { // 3. Initialize Lua + auto types_tbl = sol::object(cloe::luat_cloe_engine_types(ctx.lua)).as(); + register_usertype_coordinator(types_tbl, ctx.sync); + + cloe::luat_cloe_engine_state(ctx.lua)["scheduler"] = std::ref(*ctx.coordinator); + } + + { // 4. Enroll endpoints and triggers for the server update_progress("server"); auto rp = ctx.simulation_registrar(); @@ -428,6 +437,7 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { r.register_action(&ctx.sync); r.register_action(&ctx.statistics); r.register_action(ctx.commander.get()); + r.register_action(ctx.lua); // From: cloe/trigger/example_actions.hpp auto tr = ctx.coordinator->trigger_registrar(cloe::Source::TRIGGER); @@ -437,8 +447,7 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { r.register_action(tr); } - { - // 4. Initialize simulators + { // 5. Initialize simulators update_progress("simulators"); /** @@ -482,8 +491,7 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { cloe::handler::StaticJson(ctx.simulator_ids())); } - { - // 5. Initialize vehicles + { // 6. Initialize vehicles update_progress("vehicles"); /** @@ -673,8 +681,7 @@ StateId SimulationMachine::Connect::impl(SimulationContext& ctx) { cloe::handler::StaticJson(ctx.vehicle_ids())); } - { - // 6. Initialize controllers + { // 7. Initialize controllers update_progress("controllers"); /** @@ -759,6 +766,7 @@ StateId SimulationMachine::Start::impl(SimulationContext& ctx) { // Process initial trigger list insert_triggers_from_config(ctx); + ctx.coordinator->process_pending_lua_triggers(ctx.sync); ctx.coordinator->process(ctx.sync); ctx.callback_start->trigger(ctx.sync); @@ -1197,16 +1205,19 @@ StateId SimulationMachine::Abort::impl(SimulationContext& ctx) { // --------------------------------------------------------------------------------------------- // -Simulation::Simulation(const cloe::Stack& config, const std::string& uuid) - : logger_(cloe::logger::get("cloe")), config_(config), uuid_(uuid) {} +Simulation::Simulation(cloe::Stack&& config, sol::state&& lua, const std::string& uuid) + : config_(std::move(config)) + , lua_(std::move(lua)) + , logger_(cloe::logger::get("cloe")) + , uuid_(uuid) {} SimulationResult Simulation::run() { // Input: - SimulationContext ctx; + SimulationContext ctx{lua_.lua_state()}; ctx.server = make_server(config_.server); - ctx.coordinator.reset(new Coordinator{}); - ctx.registrar.reset(new Registrar{ctx.server->server_registrar(), ctx.coordinator}); - ctx.commander.reset(new CommandExecuter{logger()}); + ctx.coordinator = std::make_unique(ctx.lua); + ctx.registrar = std::make_unique(ctx.server->server_registrar(), ctx.coordinator.get()); + ctx.commander = std::make_unique(logger()); ctx.sync = SimulationSync(config_.simulation.model_step_width); ctx.config = config_; ctx.uuid = uuid_; @@ -1313,6 +1324,7 @@ SimulationResult Simulation::run() { r.statistics = ctx.statistics; r.elapsed = ctx.progress.elapsed(); r.triggers = ctx.coordinator->history(); + r.report = sol::object(cloe::luat_cloe_engine_state(ctx.lua)["report"]); abort_fn_ = nullptr; return r; diff --git a/engine/src/simulation.hpp b/engine/src/simulation.hpp index ac3de4edd..ce2105861 100644 --- a/engine/src/simulation.hpp +++ b/engine/src/simulation.hpp @@ -28,6 +28,7 @@ #include // for path #include // for ENUM_SERIALIZATION +#include // for state #include "simulation_context.hpp" #include "stack.hpp" // for Stack @@ -44,6 +45,7 @@ struct SimulationResult { std::vector errors; SimulationStatistics statistics; cloe::Json triggers; + cloe::Json report; boost::optional output_dir; public: @@ -97,6 +99,7 @@ struct SimulationResult { {"elapsed", r.elapsed}, {"errors", r.errors}, {"outcome", r.outcome}, + {"report", r.report}, {"simulation", r.sync}, {"statistics", r.statistics}, {"uuid", r.uuid}, @@ -106,7 +109,7 @@ struct SimulationResult { class Simulation { public: - Simulation(const cloe::Stack& config, const std::string& uuid); + Simulation(cloe::Stack&& config, sol::state&& lua, const std::string& uuid); ~Simulation() = default; /** @@ -151,6 +154,7 @@ class Simulation { private: cloe::Stack config_; + sol::state lua_; cloe::Logger logger_; std::string uuid_; std::function abort_fn_; diff --git a/engine/src/simulation_context.hpp b/engine/src/simulation_context.hpp index d64bb2268..9bdbc36ff 100644 --- a/engine/src/simulation_context.hpp +++ b/engine/src/simulation_context.hpp @@ -30,6 +30,8 @@ #include // for string #include // for vector<> +#include // for state_view + #include // for Simulator, Controller, Registrar, Vehicle, Duration #include // for Sync #include // for DEFINE_NIL_EVENT @@ -194,6 +196,10 @@ DEFINE_NIL_EVENT(Loop, "loop", "begin of inner simulation loop each cycle") * performed in the simulation states in the `simulation.cpp` file. */ struct SimulationContext { + SimulationContext(sol::state_view&& l) : lua(l) {} + + sol::state_view lua; + // Setup std::unique_ptr server; std::shared_ptr coordinator; diff --git a/fable/conanfile.py b/fable/conanfile.py index 503b19740..25f2ff040 100644 --- a/fable/conanfile.py +++ b/fable/conanfile.py @@ -52,7 +52,7 @@ def requirements(self): def build_requirements(self): self.test_requires("gtest/1.14.0") self.test_requires("boost/1.74.0") - self.test_requires("sol2/3.3.0") + self.test_requires("sol2/3.3.1") def layout(self): cmake.cmake_layout(self) diff --git a/plugins/basic/src/basic.cpp b/plugins/basic/src/basic.cpp index 36ca03fac..27baf6648 100644 --- a/plugins/basic/src/basic.cpp +++ b/plugins/basic/src/basic.cpp @@ -31,6 +31,7 @@ #include // for vector<> #include // for Schema +#include #include // for DriverRequest #include // for LatLongActuator @@ -395,6 +396,40 @@ class BasicController : public Controller { } void enroll(Registrar& r) override { + auto lua = r.register_lua_table(); + + { + auto acc = lua.new_usertype("AccConfiguration", sol::no_constructor); + acc["ego_sensor"] = sol::readonly(&AccConfiguration::ego_sensor); + acc["world_sensor"] = sol::readonly(&AccConfiguration::world_sensor); + acc["latlong_actuator"] = sol::readonly(&AccConfiguration::latlong_actuator); + acc["limit_acceleration"] = &AccConfiguration::limit_acceleration; + acc["limit_deceleration"] = &AccConfiguration::limit_deceleration; + acc["derivative_factor_speed_control"] = &AccConfiguration::kd; + acc["proportional_factor_speed_control"] = &AccConfiguration::kp; + acc["integral_factor_speed_control"] = &AccConfiguration::ki; + acc["derivative_factor_dist_control"] = &AccConfiguration::kd_m; + acc["proportional_factor_dist_control"] = &AccConfiguration::kp_m; + acc["integral_factor_dist_control"] = &AccConfiguration::ki_m; + + auto inst = lua.create("acc"); + inst["config"] = std::ref(acc_.config); + inst["enabled"] = &acc_.enabled; + inst["active"] = &acc_.active; + inst["distance_algorithm"] = sol::property( + [this]() -> std::string { return distance::ALGORITHMS[acc_.distance_algorithm].first; }, + [this](const std::string& name) { + for (size_t i = 0; i < distance::ALGORITHMS.size(); i++) { + if (distance::ALGORITHMS[i].first == name) { + acc_.distance_algorithm = i; + return; + } + } + // FIXME: Throw an error here + }); + inst["target_speed"] = &acc_.target_speed; + } + r.register_action(std::make_unique>(&hmi_)); // clang-format off diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index d49101630..771837e16 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -21,6 +21,7 @@ if(NOT TARGET pantor::inja) add_library(pantor::inja ALIAS inja::inja) endif() find_package(incbin REQUIRED QUIET) +find_package(sol2 REQUIRED QUIET) file(GLOB cloe-runtime_PUBLIC_HEADERS "include/**/*.hpp") message(STATUS "Building cloe-runtime library.") @@ -70,6 +71,7 @@ target_link_libraries(cloe-runtime Boost::system fable::fable spdlog::spdlog + sol2::sol2 INTERFACE pantor::inja incbin::incbin @@ -79,6 +81,7 @@ target_compile_definitions(cloe-runtime PROJECT_SOURCE_DIR=\"${CMAKE_CURRENT_SOURCE_DIR}\" PUBLIC _USE_MATH_DEFINES=1 + SOL_ALL_SAFETIES_ON=1 ) # Testing ------------------------------------------------------------- @@ -103,6 +106,7 @@ if(BUILD_TESTING) GTest::gtest GTest::gtest_main Boost::boost + sol2::sol2 cloe-runtime ) gtest_add_tests(TARGET test-cloe) diff --git a/runtime/conanfile.py b/runtime/conanfile.py index 9cde7269f..87fa99f05 100644 --- a/runtime/conanfile.py +++ b/runtime/conanfile.py @@ -49,6 +49,7 @@ def requirements(self): self.requires("inja/3.4.0") self.requires("spdlog/1.11.0") self.requires("incbin/cci.20211107") + self.requires("sol2/3.3.1") def build_requirements(self): self.test_requires("gtest/1.14.0") diff --git a/runtime/include/cloe/registrar.hpp b/runtime/include/cloe/registrar.hpp index f96065318..1c6aa49ab 100644 --- a/runtime/include/cloe/registrar.hpp +++ b/runtime/include/cloe/registrar.hpp @@ -26,6 +26,7 @@ #include // for string #include // for move +#include #include // for Handler #include // for ActionFactory, EventFactory, Callback, ... #include // for Json @@ -197,6 +198,11 @@ class Registrar { */ virtual void register_event(std::unique_ptr&& f, std::shared_ptr c) = 0; + /** + * Provide a Lua table for registration of functions and variables. + */ + virtual sol::table register_lua_table() = 0; + /** * Register an EventFactory and return a DirectCallback for storage of * events. diff --git a/runtime/include/cloe/trigger.hpp b/runtime/include/cloe/trigger.hpp index af832fc72..d78886c10 100644 --- a/runtime/include/cloe/trigger.hpp +++ b/runtime/include/cloe/trigger.hpp @@ -365,6 +365,9 @@ enum class Source { /// Triggers that are instance of a sticky trigger. INSTANCE, + + /// Triggers that originate from a Lua script + LUA, }; // clang-format off @@ -374,6 +377,7 @@ ENUM_SERIALIZATION(Source, ({ {Source::MODEL, "model"}, {Source::TRIGGER, "trigger"}, {Source::INSTANCE, "instance"}, + {Source::LUA, "lua"}, })) // clang-format on @@ -386,7 +390,7 @@ ENUM_SERIALIZATION(Source, ({ * reproduction. */ inline bool source_is_transient(Source s) { - return (s != Source::FILESYSTEM && s != Source::NETWORK); + return (s != Source::FILESYSTEM && s != Source::NETWORK && s != Source::LUA); } /** diff --git a/tests/project.lua b/tests/project.lua new file mode 100644 index 000000000..bdd01f352 --- /dev/null +++ b/tests/project.lua @@ -0,0 +1,256 @@ +-- This file configures the "project", as it were. +-- +-- It more or less does the exact same as config_nop_smoketest.json, +-- except in a more configurable, modular way. +-- +-- You will notice that this file follows the same format as a Lua +-- module, and indeed, this is how it is expected to be used: +-- +-- local project = require("project") +-- project.configure_all { +-- with_server = false, +-- with_noisy_sensor = true, +-- } +-- +-- In order to be maximally useful to users, each of the defined +-- functions should be documented so that the Lua Language Server +-- can give the users auto-completion and documentation hints. +local api = require("cloe-engine") +local cloe = require("cloe") +local luax = require("cloe.luax") +local shapes = require("tableshape").types +local system = require("cloe.system") + +local m = {} + +--- Initialize project specific report metadata. +--- +--- @param ... table +--- @return nil +function m.init_report(...) + local results = {} + local file = api.state.current_script_file + if file then + results["source"] = cloe.fs.realpath(file) + end + results["datetime"] = system.get_datetime() + for _, tbl in ipairs({ ... }) do + results = luax.tbl_extend("force", results, tbl) + end + + api.state.report.metadata = results +end + +--- Apply a stackfile, setting version to "4". +--- +--- These lines modify the input stack to make it +--- conform to the current stackfile version. +--- Call it a minimal "quality-of-life improvement". +--- +--- @param spec table Stack input table +--- @return nil -- Return value of cloe.apply_stack +m.apply_stack = function(spec) + cloe.validate("project.apply_stack(table)", spec) + spec.version = "4" + return cloe.apply_stack(spec) +end + +--- @class ProjectOptions +--- @field with_server? boolean +--- @field with_noisy_sensor? boolean +local ProjectOptions = shapes.shape({ + with_server = shapes.boolean:is_optional(), + with_noisy_sensor = shapes.boolean:is_optional(), +}) + +--- Configure all aspects of the simulation. +--- +--- This just calls the other configure_* methods in a convenient way. +--- +--- @param opts ProjectOptions +--- @return nil +m.configure_all = function(opts) + opts = opts or {} + cloe.validate_shape("project.configure_all(ProjectOptions)", ProjectOptions, opts) + + local vehname = "default" + local simname = "nop" + + m.configure_nop_simulator(simname) + m.configure_vehicle(vehname, simname, { + with_noisy_sensor = opts.with_noisy_sensor, + }) + m.configure_server(opts.with_server) + m.configure_virtue(vehname) + m.configure_basic(vehname) +end + +--- Configure the simulator. +--- +--- @param name string Name of simulator (e.g. "sim") +--- @return nil -- Return value of cloe.apply_stack +m.configure_nop_simulator = function(name) + cloe.validate("project.configure_nop_simulator(string)", name) + + return m.apply_stack({ + simulators = { + { binding = "nop", name = name }, + }, + }) +end + +--- @class VehicleOptions +--- @field with_noisy_sensor? boolean +local VehicleOptions = shapes.shape({ + with_noisy_sensor = shapes.boolean:is_optional(), +}) + +--- Configure the vehicle. +--- +--- @param name string Name of the vehicle +--- @param simulator string|table Name of simulator or config block +--- @param opts VehicleOptions +--- @return nil -- Return value of cloe.apply_stack +m.configure_vehicle = function(name, simulator, opts) + cloe.validate("project.configure_vehicle(string, string|table, ?table)", name, simulator, opts) + local from = simulator + if type(from) == "string" then + from = { + simulator = simulator, + index = 0, + } + end + opts = opts or {} + cloe.validate_shape("project.configure_vehicle(string, string|table, VehicleOptions)", VehicleOptions, opts) + + local components = { + ["cloe::speedometer"] = { + binding = "speedometer", + name = "default_speed", + from = "cloe::gndtruth_ego_sensor", + }, + } + if opts.with_noisy_sensor then + components["cloe::default_world_sensor"] = { + binding = "noisy_object_sensor", + name = "noisy_object_sensor", + from = "cloe::default_world_sensor", + args = { + noise = { + { + target = "translation", + distribution = { + binding = "normal", + mean = 0.0, + std_deviation = 0.3, + }, + }, + }, + }, + } + end + + if from.simulator == "nop" then + cloe.schedule({ + desc = "Vehicle should never move with nop binding", + on = "default_speed/kmph=>0.0", + run = "fail", + }) + end + + return m.apply_stack({ + vehicles = { + { name = name, from = from, components = components }, + }, + }) +end + +--- Configure server if possible. +--- +--- @param enable boolean +--- @return nil +m.configure_server = function(enable) + if enable then + -- Query system to see if something is already listening on the port. + local code = os.execute("ss -H -l 'sport = 23456' | grep tcp") + if code == 0 then + cloe.log("error", "a process is already listening at 23456") + enable = false + end + end + return m.apply_stack({ + server = { + listen = enable, + listen_port = 23456, + }, + }) +end + +--- Configure virtue controller for vehicle. +--- +--- @param vehicle string Vehicle name +--- @return nil +m.configure_virtue = function(vehicle) + m.apply_stack({ + controllers = { + { binding = "virtue", vehicle = vehicle }, + }, + }) + cloe.schedule({ on = "virtue/failure", run = "fail" }) +end + +--- Configure basic controller for vehicle. +--- +--- @param vehicle string Vehicle name +--- @return nil +m.configure_basic = function(vehicle) + m.apply_stack({ + controllers = { + { binding = "basic", vehicle = vehicle }, + }, + }) + cloe.schedule_these({ + { on = "start", run = "basic/hmi=!enable" }, + { on = "next=1", run = "basic/hmi=enable" }, + { on = "time=5", run = "basic/hmi=resume" }, + { on = "time=5.5", run = "basic/hmi=!resume" }, + }) +end + +--- Set realtime factor. +--- +--- @param factor number Use -1 for maximum speed, 1.0 for realtime +--- @return nil +m.set_realtime_factor = function(factor) + if factor == 0 then + error("cannot set realtime factor of 0") + end + cloe.schedule({ on = "start", run = "realtime_factor=" .. tostring(factor) }) +end + +--- Do an action after given duration. +--- +--- @param duration string Duration with unit of time, e.g. "5s" or "5000ms" +--- @param action string|function Anything that can be scheduled +--- @return nil +m.action_after = function(duration, action) + local dur = cloe.Duration.new(duration) + return cloe.schedule({ on = "next=" .. dur:s(), run = action }) +end + +--- Fail after this amount of time. +m.fail_after = function(duration) + return m.action_after(duration, "fail") +end + +--- Succeed after this amount of time. +m.succeed_after = function(duration) + return m.action_after(duration, "succeed") +end + +--- Stop after this amount of time. +m.stop_after = function(duration) + return m.action_after(duration, "stop") +end + +return m diff --git a/tests/report_config.lua b/tests/report_config.lua new file mode 100644 index 000000000..fec9c1b4a --- /dev/null +++ b/tests/report_config.lua @@ -0,0 +1,7 @@ +return { + hostname = "node01", + username = "jenkins", + device_under_test = {}, + simulator = {}, + docker_version = "24.0.5", +} diff --git a/tests/test_engine_json_schema.json b/tests/test_engine_json_schema.json index c2149ffdd..bdbcbdadb 100644 --- a/tests/test_engine_json_schema.json +++ b/tests/test_engine_json_schema.json @@ -1413,7 +1413,8 @@ "network", "model", "trigger", - "instance" + "instance", + "lua" ], "type": "string" }, diff --git a/tests/test_engine_lua.json b/tests/test_engine_lua.json new file mode 100644 index 000000000..9cabbaf34 --- /dev/null +++ b/tests/test_engine_lua.json @@ -0,0 +1,15 @@ +{ + "version": "4", + "include": [ + "config_nop_smoketest.json" + ], + "triggers": [ + { + "event": "time=5", + "action": { + "name": "lua", + "script": "cloe.log_info(\"Hello world! Step is \" .. tostring(cloe.step()))" + } + } + ] +} diff --git a/tests/test_lua.bats b/tests/test_lua.bats new file mode 100755 index 000000000..cd8dc2579 --- /dev/null +++ b/tests/test_lua.bats @@ -0,0 +1,107 @@ +#!/usr/bin/env bats + +load setup_bats +load setup_testname + +@test "$(testname 'Expect success' 'test_lua01_include_json.lua' '224b2b67-1aaf-4ba2-855c-9bf986574e30')" { + cloe-engine run test_lua01_include_json.lua +} + +@test "$(testname 'Expect success' 'test_lua02_schedule.lua' '93053c17-af8d-461d-b457-e6722e857306')" { + cloe-engine run --allow-empty test_lua02_schedule.lua +} + +@test "$(testname 'Expect success' 'test_lua03_schedule_unpin.lua' '93fe7665-688a-48f5-bd66-4a20a6711ce9')" { + cloe-engine run test_lua03_schedule_unpin.lua +} + +@test "$(testname 'Expect success' 'test_lua04_schedule_test.lua' 'e03fc31f-586b-4e57-80fa-ff2cba5ff9dd')" { + cloe-engine run test_lua04_schedule_test.lua +} + +@test "$(testname 'Expect success' 'test_lua05_apply_stack.lua' 'bbee495e-8e19-4ffb-912f-fa75840a7944')" { + cloe-engine run test_lua05_apply_stack.lua +} + +@test "$(testname 'Expect success' 'test_lua06_apply_stack.lua' 'ba5b7fbd-3b47-4767-b7b2-1075bdaa736f')" { + cloe-engine run test_lua06_apply_stack.lua +} + +@test "$(testname 'Expect success' 'test_lua07_schedule_pause.lua' '41ece52e-146a-414d-b93d-4bc4512c49b8')" { + require_program netcat + require_program curl + + cloe-engine run test_lua07_schedule_pause.lua +} + +@test "$(testname 'Expect success' 'test_lua08_apply_project.lua' '037010ed-7b08-4874-94bd-27d959bdfaca')" { + cloe-engine run test_lua08_apply_project.lua +} + +@test "$(testname 'Expect success' 'test_lua08_apply_project.lua (2)' '037010ed-7b08-4874-94bd-27d959bdfaca')" { + cd .. + cloe-engine run tests/test_lua08_apply_project.lua +} + +@test "$(testname 'Expect success' 'test_lua09_no_json.lua' '5a0fe683-355c-4584-97ea-fa012f40fa81')" { + cloe-engine run test_lua09_no_json.lua +} + +@test "$(testname 'Expect success' 'test_lua10_heavy_cpu.lua' 'fbf32388-a80e-4fb3-b334-b4cd4f020cdb')" { + cloe-engine run test_lua10_heavy_cpu.lua +} + +@test "$(testname 'Expect success' 'test_lua11_serial_tests.lua' '852edc33-a344-437e-b11d-82527a0ea387')" { + cloe-engine run test_lua11_serial_tests.lua +} + +@test "$(testname 'Expect failure' 'test_lua12_fail_after_stop.lua' '880875e8-b7ad-4d86-abf5-b2cd31b1a1db')" { + run cloe-engine run test_lua12_fail_after_stop.lua + assert_check_failure $status $output +} + +@test "$(testname 'Expect success' 'test_lua13_bdds_eval.lua' 'd7f31aaa-ccab-421b-a9ae-06aa3835018b')" { + cloe-engine run test_lua13_bdd_eval.lua +} + +# --- API --------------------------------------------------------------------- + +@test "$(testname 'Check API' 'test_lua_api_cloe_system.lua' '23496512-a7f9-4fb7-8ed3-a655954b24f7')" { + cloe-engine shell test_lua_api_cloe_system.lua +} + +@test "$(testname 'Check API' 'test_lua_api_cloe_typecheck.lua' 'd10cfa73-c03e-4e3d-876d-60d7c5c0ee73')" { + cloe-engine shell test_lua_api_cloe_typecheck.lua +} + +# --- Better errors ----------------------------------------------------------- + +@test "$(testname 'Expect failure' 'test_lua_error_main.lua' '9cc0c5a4-5771-4cec-befe-ae49bd3e0cae')" { + run cloe-engine run test_lua_error_main.lua + assert_check_failure $status $output + echo "$output" | grep "test_lua_error_main.lua:.*: expect error" +} + +@test "$(testname 'Expect failure' 'test_lua_error_coroutine.lua' '9cc0c5a4-5771-4cec-befe-ae49bd3e0cae')" { + run cloe-engine run test_lua_error_coroutine.lua + assert_check_failure $status $output + echo "$output" | grep "test_lua_error_coroutine.lua:.*: expect error" +} + +@test "$(testname 'Expect failure' 'test_lua_error_schedule.lua' '9cc0c5a4-5771-4cec-befe-ae49bd3e0cae')" { + run cloe-engine run test_lua_error_schedule.lua + assert_check_failure $status $output + echo "$output" | grep "test_lua_error_schedule.lua:.*: expect error" +} + +@test "$(testname 'Expect failure' 'test_lua_error_schedule_test.lua' '9cc0c5a4-5771-4cec-befe-ae49bd3e0cae')" { + run cloe-engine run test_lua_error_schedule_test.lua + assert_check_failure $status $output + echo "$output" | grep "test_lua_error_schedule_test.lua:.*: expect error" +} + +@test "$(testname 'Expect segfault' 'test_lua_error_segfault_on_resume.lua' 'df2ac431-7c4d-4253-a38b-f42e0f58a2b2')" { + run cloe-engine run test_lua_error_segfault_on_resume.lua + assert_check_failure $status $output + echo "$output" | grep "segfault" +} diff --git a/tests/test_lua01_include_json.lua b/tests/test_lua01_include_json.lua new file mode 100644 index 000000000..336f3f100 --- /dev/null +++ b/tests/test_lua01_include_json.lua @@ -0,0 +1,17 @@ +-- Simplest configuration just loads an existing stackfile. +-- + +-- The `cloe` table is loaded by default within cloe-engine. +-- +-- This will do nothing in cloe-engine, but it will give us +-- completion when using a Lua LSP. +local cloe = require("cloe") + +-- cloe.has_feature() lets us check whether we have a certain +-- version or feature of cloe. +assert(cloe.has_feature("cloe-0.21"), "cloe is not recent enough") + +-- cloe.load_stackfile() lets us load a stackfile. This may +-- be dropped or shimmed in the future. +cloe.require_feature("cloe-stackfile-4") +cloe.load_stackfile("config_nop_smoketest.json") diff --git a/tests/test_lua02_schedule.lua b/tests/test_lua02_schedule.lua new file mode 100644 index 000000000..4e4bb3109 --- /dev/null +++ b/tests/test_lua02_schedule.lua @@ -0,0 +1,31 @@ +-- This example shows that you don't actually need any plugins at +-- all to have a simulation. You can simple schedule some tasks. +local cloe = require("cloe") +local events, actions = cloe.events, cloe.actions + +cloe.log("info", "Hello world!"); + +cloe.schedule { + on = events.loop(), + priority = 101, -- higher than the default + pin = false, + run = function(_) + cloe.log("info", "Hello world!") + end +} + +cloe.schedule { + on = events.loop(), + pin = true, + run = function(sync) + cloe.log("info", "Current time is %s", sync:time()) + end +} + +-- If you configure this to terminate at for example, 60 seconds, +-- you may notice that we don't even achieve realtime. This is +-- because of the server. See the next test file. +cloe.schedule { + on = events.time("1s"), + run = actions.succeed(), +} diff --git a/tests/test_lua03_schedule_unpin.lua b/tests/test_lua03_schedule_unpin.lua new file mode 100644 index 000000000..7f167c736 --- /dev/null +++ b/tests/test_lua03_schedule_unpin.lua @@ -0,0 +1,27 @@ +-- This example shows that you don't actually need any plugins at +-- all to have a simulation. You can simple schedule some tasks. +local cloe = require("cloe") +local events = cloe.events + +cloe.load_stackfile("config_nop_smoketest.json") + +cloe.schedule { + on = events.loop(), + priority = 101, -- higher than the default + pin = false, + run = function(_) + cloe.log("info", "Hello world!") + end +} + +cloe.schedule { + on = events.every("1s"), + pin = true, + run = function(sync) + cloe.log("info", "Current time is %s", sync:time()) + if sync:time():s() > 30 then + -- FIXME: Unpin is not working + return false + end + end +} diff --git a/tests/test_lua04_schedule_test.lua b/tests/test_lua04_schedule_test.lua new file mode 100644 index 000000000..6df795995 --- /dev/null +++ b/tests/test_lua04_schedule_test.lua @@ -0,0 +1,43 @@ +local cloe = require("cloe") +local events, actions = cloe.events, cloe.actions + +cloe.load_stackfile("config_nop_smoketest.json") + +-- If schedule_test does not work, then we will keep running until +-- this event triggers and we fail. +cloe.schedule { + on = events.time("5s"), + run = actions.fail(), +} + +cloe.schedule { + on = events.every("1s"), + pin = true, + run = function(sync) + cloe.log("info", "Current time is %s", sync:time()) + end +} + +-- Check that schedule_test works as intended. +cloe.schedule_test { + -- Note that this is the same ID as used in BATS. + id = "e03fc31f-586b-4e57-80fa-ff2cba5ff9dd", + on = events.start(), + terminate = false, + run = function(z, sync) + cloe.log("info", "Entering test") + z:assert(true) + + cloe.log("info", "Asserting something...") + z:assert_eq(sync:time():s(), 0, "time at start is 0s") + + cloe.log("info", "Waiting 1s...") + z:wait_duration("1s") + + z:assert(sync:time() >= cloe.Duration.new("1s"), "yield does not work and the time has not advanced") + z:assert(sync:time() == cloe.Duration.new("1s"), "time has advanced the wrong amount") + + cloe.log("info", "We're good here.") + z:succeed() + end +} diff --git a/tests/test_lua05_apply_stack.lua b/tests/test_lua05_apply_stack.lua new file mode 100644 index 000000000..44dda53ab --- /dev/null +++ b/tests/test_lua05_apply_stack.lua @@ -0,0 +1,67 @@ +local cloe = require("cloe") + +cloe.apply_stack { + version = "4", + include = { + "config_nop_infinite.json", + }, + server = { + listen = false, + listen_port = 23456, + }, +} + +-- All the conditions we want to fail on: +cloe.schedule_these { + run = cloe.actions.fail(), + { on = "virtue/failure" }, + { on = "default_speed/kmph=>0.0" }, + { on = "time=5" }, +} + +-- All the things we want to do on start: +cloe.schedule_these { + on = cloe.events.start(), + { run = "log=info: Running nop/basic smoketest." }, + { run = "realtime_factor=-1" }, +} + +cloe.schedule { + on = cloe.events.loop(), + pin = true, + run = function(sync) + cloe.log(cloe.LogLevel.INFO, "Current time is %s", sync:time()) + end +} + +cloe.schedule_test { + id = "precondition", + on = cloe.events.start(), + run = function(z) + z:printf("hello there") + end +} + +-- Check that schedule_test works as intended. +cloe.schedule_test { + -- Note that this is the same ID as used in BATS. + id = "e03fc31f-586b-4e57-80fa-ff2cba5ff9dd", + on = cloe.events.start(), + terminate = false, + run = function(z, sync) + z:printf("Entering test") + z:assert(true, "true is truthy") + + z:printf("Asserting something...") + z:assert_eq(sync:time():s(), 0, "time at start is 0s") + + z:printf("Waiting 1s...") + z:wait_duration("1s") + + z:assert_ge(sync:time(), cloe.Duration.new("1s"), "yield works and the time has advanced") + z:assert_eq(sync:time(), cloe.Duration.new("1s"), "time has advanced the right amount") + + z:printf("We're good here.") + z:succeed() + end +} diff --git a/tests/test_lua06_apply_stack.lua b/tests/test_lua06_apply_stack.lua new file mode 100644 index 000000000..82c0b1205 --- /dev/null +++ b/tests/test_lua06_apply_stack.lua @@ -0,0 +1,67 @@ +local cloe = require("cloe") + +cloe.apply_stack [[{ + "version": "4", + "include": [ + "config_nop_infinite.json" + ], + "server": { + "listen": false, + "listen_port": 23456 + } +}]] + +-- All the conditions we want to fail on: +cloe.schedule_these { + run = cloe.actions.fail(), + { on = "virtue/failure" }, + { on = "default_speed/kmph=>0.0" }, + { on = "time=5" }, +} + +-- All the things we want to do on start: +cloe.schedule_these { + on = cloe.events.start(), + { run = "log=info: Running nop/basic smoketest." }, + { run = "realtime_factor=-1" }, +} + +cloe.schedule { + on = cloe.events.loop(), + pin = true, + run = function(sync) + cloe.log(cloe.LogLevel.INFO, "Current time is %s", sync:time()) + end +} + +cloe.schedule_test { + id = "precondition", + on = cloe.events.start(), + run = function(z) + z:printf("hello there") + end +} + +-- Check that schedule_test works as intended. +cloe.schedule_test { + -- Note that this is the same ID as used in BATS. + id = "e03fc31f-586b-4e57-80fa-ff2cba5ff9dd", + on = cloe.events.start(), + terminate = false, + run = function(z, sync) + z:printf("Entering test") + z:assert(true, "true is truthy") + + z:printf("Asserting something...") + z:assert_eq(sync:time():s(), 0, "time at start is 0s") + + z:printf("Waiting 1s...") + z:wait_duration("1s") + + z:assert_ge(sync:time(), cloe.Duration.new("1s"), "yield works and the time has advanced") + z:assert_eq(sync:time(), cloe.Duration.new("1s"), "time has advanced the right amount") + + z:printf("We're good here.") + z:succeed() + end +} diff --git a/tests/test_lua07_schedule_pause.lua b/tests/test_lua07_schedule_pause.lua new file mode 100644 index 000000000..158a8684d --- /dev/null +++ b/tests/test_lua07_schedule_pause.lua @@ -0,0 +1,74 @@ +local cloe = require("cloe") + +cloe.apply_stack { + version = "4", + include = { + "config_nop_infinite.json", + }, + engine = { + security = { + enable_command_action = true + }, + -- This test case will hang in failure, so enable the watchdog. + watchdog = { + mode = "kill", + } + }, + server = { + listen = false, + listen_port = 23456, + }, +} + +cloe.schedule_these { + on = cloe.events.start(), + { + -- Resume from outside + run = { + name = "command", + mode = "async", + command = [[ + curl --retry 10 --retry-connrefused --retry-delay 1 localhost:7890 + ]] + } + }, + { + -- Pause + run = { + name = "command", + command = [[ + echo "Resume with: curl localhost:7890"; + echo OK | netcat -l localhost 7890 + ]] + } + } +} + +-- Alternate implementation wih Lua API +local system = require("cloe.system") +cloe.schedule_these { + on = cloe.events.time("500ms"), + { + -- Resume from outside + run = function() + system.exec { + command = "curl --retry 10 --retry-connrefused --retry-delay 1 localhost:7892", + mode = "detach" + } + end + }, + { + -- Pause + run = function() + system.exec [[ + echo "Resume with: curl localhost:7892"; + echo OK | netcat -l localhost 7892 + ]] + end + } +} + +cloe.schedule { + on = cloe.events.time("1s"), + run = cloe.actions.succeed() +} diff --git a/tests/test_lua08_apply_project.lua b/tests/test_lua08_apply_project.lua new file mode 100644 index 000000000..4079a1bfa --- /dev/null +++ b/tests/test_lua08_apply_project.lua @@ -0,0 +1,40 @@ +local cloe = require("cloe") + +do + local proj = cloe.require("project") + proj.configure_all { + with_server = false, + with_noisy_sensor = true, + } + proj.set_realtime_factor(-1) +end + +cloe.schedule { + on = "loop", + pin = true, + run = function(sync) + cloe.log("info", "Current time is %s", sync:time()) + end +} + +-- Check that schedule_test works as intended. +cloe.schedule_test { + -- Note that this is the same ID as used in BATS. + id = "e03fc31f-586b-4e57-80fa-ff2cba5ff9dd", + on = "start", + run = function(z, sync) + z:printf("Entering test") + z:expect("string") + + z:printf("Asserting something...") + z:assert_eq(sync:time():s(), 0, "time is 0s at start") + + z:printf("Waiting 1s...") + z:wait_duration("1s") + + z:assert_ge(sync:time(), cloe.Duration.new("1s"), "time has advanced") + z:assert_eq(sync:time(), cloe.Duration.new("1s"), "time has advanced exactly 1s") + + z:printf("We're good here.") + end +} diff --git a/tests/test_lua09_no_json.lua b/tests/test_lua09_no_json.lua new file mode 100644 index 000000000..0dbf56abe --- /dev/null +++ b/tests/test_lua09_no_json.lua @@ -0,0 +1,93 @@ +local cloe = require("cloe") + +-- From: config_nop_infinite.json +cloe.apply_stack { + version = "4", + simulators = { + { binding = "nop" } + }, + vehicles = { + { + name = "default", + from = { + simulator = "nop", + index = 0 + }, + components = { + ["cloe::speedometer"] = { + binding = "speedometer", + name = "default_speed", + from = "cloe::gndtruth_ego_sensor" + }, + ["cloe::default_world_sensor"] = { + binding = "noisy_object_sensor", + name = "noisy_object_sensor", + from = "cloe::default_world_sensor", + args = { + noise = { + { + target = "translation", + distribution = { + binding = "normal", + mean = 0.0, + std_deviation = 0.3 + } + } + } + } + } + } + } + } +} + +-- From: controller_basic.json +cloe.apply_stack { + version = "4", + controllers = { + { binding = "basic", vehicle = "default" } + }, + triggers = { + { action = { actions = { "basic/hmi=!enable" }, name = "bundle" }, event = "start" }, + { action = "basic/hmi=enable", event = "next=1" }, + { action = "basic/hmi=resume", event = "time=5" }, + { action = "basic/hmi=!resume", event = "time=5.5" }, + { action = { + name = "insert", + triggers = { { + action = "basic/hmi=plus", + event = "next" + }, { + action = "basic/hmi=!plus", + event = "next=1" + } } + }, + event = "time=6", + label = "Push and release basic/hmi=plus" + } + }, +} + +-- From: controller_virtue.json +cloe.apply_stack { + version = "4", + controllers = { + { binding = "virtue", vehicle = "default" } + }, +} + +-- From: config_nop_smoketest.json +cloe.apply_stack { + version = "4", + server = { + listen = false, + listen_port = 23456 + }, + triggers = { + { action = "fail", event = "virtue/failure" }, + { action = "fail", event = "default_speed/kmph=>0.0", label = "Vehicle default should never move with the nop binding." }, + { action = "log=info: Running nop/basic smoketest.", event = "start" }, + { action = "realtime_factor=-1", event = "start" }, + { action = "succeed", event = "time=60" }, + }, +} diff --git a/tests/test_lua10_heavy_cpu.lua b/tests/test_lua10_heavy_cpu.lua new file mode 100644 index 000000000..bd59f0952 --- /dev/null +++ b/tests/test_lua10_heavy_cpu.lua @@ -0,0 +1,33 @@ +local cloe = require("cloe") +local proj = cloe.require("project") + +proj.configure_all { + with_server = false, +} +proj.stop_after("1s") +proj.set_realtime_factor(-1) + +cloe.schedule { + on = "loop", + pin = true, + run = function() + ARRAY_SIZE = 1000 + + local array = {} + for i = 1, ARRAY_SIZE do + array[i] = math.random() + end + table.sort(array) + end +} + +cloe.schedule { + on = "stop", + run = function(sync) + if sync:achievable_realtime_factor() > 1 then + cloe.execute_action("succeed") + else + cloe.execute_action("fail") + end + end +} diff --git a/tests/test_lua11_serial_tests.lua b/tests/test_lua11_serial_tests.lua new file mode 100644 index 000000000..23dda77ab --- /dev/null +++ b/tests/test_lua11_serial_tests.lua @@ -0,0 +1,47 @@ +local cloe = require("cloe") +local proj = cloe.require("project") + +proj.configure_all { + with_server = false, + with_noisy_sensor = true, +} + +proj.set_realtime_factor(-1) +proj.fail_after("10s") + +TRIGGER_NEXT = false + +-- Check that schedule_test works as intended. +cloe.schedule_test { + id = "TEST-A", + on = "start", + + --- @param z TestFixture + run = function(z) + z:printf("Entering TEST-A") + z:wait_duration("5s") + z:printf("Waited 5 seconds, complete") + _G.TRIGGER_NEXT = true + end +} + +local dur = cloe.Duration.new + +cloe.schedule_test { + id = "TEST-B", + on = function() + cloe.log("debug", "Waiting on TRIGGER_NEXT = %s", TRIGGER_NEXT) + return TRIGGER_NEXT + end, + run = function(z, sync) + z:printf("Entering TEST-B") + z:assert(sync:time() == (dur("5s") + sync:step_width()), "TEST-B should start after TEST-A completed, at 5s") + end +} + +cloe.schedule { + on = "stop", + run = function(sync) + cloe.log("info", "Simulation time is %s", sync:time()) + end +} diff --git a/tests/test_lua12_fail_after_stop.lua b/tests/test_lua12_fail_after_stop.lua new file mode 100644 index 000000000..bcba82bf3 --- /dev/null +++ b/tests/test_lua12_fail_after_stop.lua @@ -0,0 +1,34 @@ +local cloe = require("cloe") +local proj = cloe.require("project") + +proj.configure_all { + with_server = false, + with_noisy_sensor = true, +} + +proj.set_realtime_factor(-1) + +-- Check that schedule_test works as intended. +cloe.schedule_test { + id = "TEST-A", + on = cloe.events.time("1s"), + run = function(z) + z:expect(false, "This has been a bad test!") + end +} + +cloe.schedule_test { + id = "TEST-B", + on = cloe.events.time("5s"), + run = function(z) + z:expect(true, "TEST-B has been a good test!") + end +} + +cloe.schedule { + on = "stop", + run = function(sync) + cloe.log("info", "Checking whether simulation really was successful...") + cloe.log("info", "- Simulation time is %s", sync:time()) + end +} diff --git a/tests/test_lua13_bdd_eval.lua b/tests/test_lua13_bdd_eval.lua new file mode 100644 index 000000000..588a59ce2 --- /dev/null +++ b/tests/test_lua13_bdd_eval.lua @@ -0,0 +1,61 @@ +local cloe = require("cloe") + +do + local project = require("project") + project.configure_all({ + with_server = false, + }) + project.init_report(require("report_config"), { foo = "bar" }) + project.set_realtime_factor(-1) +end + +-- Example schedule_test with testing using the Lust library: +-- +-- https://github.com/bjornbytes/lust +-- +-- Inside the run function we need to use z:describe to start +-- using the Lust library, and we need to import `lust.it` and +-- `lust.expect` from it. +cloe.schedule_test { + id = "d7f31aaa-ccab-421b-a9ae-06aa3835018b", + on = "start", + info = { hello = "test", rqm = "big number" }, + desc = "this is a long text for test", + terminate = false, + run = function(z, sync) + -- If we want to use the Lust + local lust = require("lust") + local it, expect = lust.it, lust.expect + + cloe.log("info", "Entering test") + + z:describe("test group 0", function() + it("", function() + expect(true).to.be.truthy() + end) + end) + + cloe.log("info", "Asserting something...") + + z:describe("test group 1", function() + it("time at start is 0s", function() + expect(sync:time():s()).to.be(0) + end) + end) + + cloe.log("info", "Waiting 1s...") + z:wait_duration("1s") + + z:describe("test group 2", function() + it("yield does not work and the time has not advanced", function() + expect(sync:time() >= cloe.Duration.new("1s")).to.be.truthy() + end) + it("time has advanced the wrong amount", function() + expect(sync:time() == cloe.Duration.new("1s")).to.be.truthy() + end) + end) + + cloe.log("info", "We're good here.") + z:succeed() + end, +} diff --git a/tests/test_lua_api_cloe_system.lua b/tests/test_lua_api_cloe_system.lua new file mode 100644 index 000000000..7c51999cb --- /dev/null +++ b/tests/test_lua_api_cloe_system.lua @@ -0,0 +1,18 @@ +local sys = require("cloe.system") +local ans, ec + +local function endswith(s, suffix) + return string.sub(s, -#suffix) == suffix +end + +ans, ec = sys.exec "echo hello world" +assert(ec == 0) +assert(ans == "hello world") + +ans, ec = sys.exec { + command = "echoxxx hello world", + log_output = "never", +} +assert(ec ~= 0) +print(ans) +assert(endswith(ans, "not found")) diff --git a/tests/test_lua_api_cloe_typecheck.lua b/tests/test_lua_api_cloe_typecheck.lua new file mode 100644 index 000000000..5452ff03e --- /dev/null +++ b/tests/test_lua_api_cloe_typecheck.lua @@ -0,0 +1,11 @@ +local validate = require("cloe.typecheck").validate +local argscheck = require("typecheck").argscheck + +local function log(fmt, ...) + validate("log(string, [?any]...)", fmt, ...) + print(string.format(fmt, ...)) +end + +log("hello world") +log("hello %s", "you") +log("hello %s and %s", "you", "me") diff --git a/tests/test_lua_error_coroutine.lua b/tests/test_lua_error_coroutine.lua new file mode 100644 index 000000000..5e8d6a629 --- /dev/null +++ b/tests/test_lua_error_coroutine.lua @@ -0,0 +1,15 @@ +local cloe = require("cloe") + +local co = coroutine.create(function() + error("expect error") +end) +local ok, result = coroutine.resume(co) +if not ok then + error(result) +end + +cloe.schedule { + desc = "This should not run.", + on = "start", + run = "succeed", +} diff --git a/tests/test_lua_error_main.lua b/tests/test_lua_error_main.lua new file mode 100644 index 000000000..b35a3f85b --- /dev/null +++ b/tests/test_lua_error_main.lua @@ -0,0 +1,9 @@ +local cloe = require("cloe") + +error("expect error") + +cloe.schedule { + desc = "This should not run.", + on = "start", + run = "succeed", +} diff --git a/tests/test_lua_error_schedule.lua b/tests/test_lua_error_schedule.lua new file mode 100644 index 000000000..b598de218 --- /dev/null +++ b/tests/test_lua_error_schedule.lua @@ -0,0 +1,18 @@ +-- This example shows that you don't actually need any plugins at +-- all to have a simulation. You can simple schedule some tasks. +local cloe = require("cloe") + +cloe.load_stackfile("config_nop_infinite.json") + +cloe.schedule { + on = "loop", + run = function() + error("expect error") + end +} + +cloe.schedule { + desc = "This should not run.", + on = "time=1", + run = "succeed", +} diff --git a/tests/test_lua_error_schedule_test.lua b/tests/test_lua_error_schedule_test.lua new file mode 100644 index 000000000..e25a06edc --- /dev/null +++ b/tests/test_lua_error_schedule_test.lua @@ -0,0 +1,17 @@ +local cloe = require("cloe") + +cloe.load_stackfile("config_nop_infinite.json") + +cloe.schedule_test { + id = "9cc0c5a4-5771-4cec-befe-ae49bd3e0cae", + on = "start", + run = function() + error("expect error") + end +} + +cloe.schedule { + desc = "This should not run.", + on = "start", + run = "succeed", +} diff --git a/tests/test_lua_error_segfault_on_resume.lua b/tests/test_lua_error_segfault_on_resume.lua new file mode 100644 index 000000000..ec518bb7b --- /dev/null +++ b/tests/test_lua_error_segfault_on_resume.lua @@ -0,0 +1,75 @@ +local cloe = require("cloe") + +cloe.load_stackfile("config_nop_infinite.json") + +cloe.schedule { + on = "start", + desc = "A demonstration of what works", + enable = true, + run = function() + local test = {} + local resume_main = function () + cloe.log("info", "this should not fail or segfault") + local ok, result = coroutine.resume(test.co) + if not ok then + error(result) + end + end + + test.co = coroutine.create(function() + -- Running scheduler.insert inside this function leads to segfault... + -- Not sure why, so we need to yield the function first. + coroutine.yield(function() + cloe.scheduler.insert { + event = "next", + action = resume_main, + action_source = "simplified scheduler.insert", + } + end) + cloe.log("info", "awesome, it works") + end) + local ok, result = coroutine.resume(test.co) + if not ok then + error(result) + elseif type(result) == "function" then + result() + end + end +} + +cloe.schedule { + on = "start", + desc = "A demonstration of what leads to segfault", + enable = true, + run = function() + local test = {} + local resume_main = function () + cloe.log("info", "this should not fail or segfault, but it does") + local ok, result = coroutine.resume(test.co) + if not ok then + error(result) + end + end + + test.co = coroutine.create(function() + -- Running scheduler.insert inside this function leads to segfault... + -- Not sure why, so we need to yield the function first. + cloe.scheduler.insert { + event = "next", + action = resume_main, + action_source = "simplified scheduler.insert", + } + coroutine.yield() + cloe.log("info", "it's nice that this works now, but unexpected") + end) + local ok, result = coroutine.resume(test.co) + if not ok then + error(result) + end + end +} + +cloe.schedule { + on = "time=0.1", + run = "succeed", +} From e01495838da64c1c03ec4c39f6175d5b802fb261 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 23:28:38 +0200 Subject: [PATCH 09/22] engine: Add LRDB debugging support - Add slimmed down LRDB library to the engine/vendor/lrdb directory. - It can be compiled in or out via option (in by default) - The command-line flags are --debug-lua and --debug-lua-port --- NOTICE.md | 6 + conanfile.py | 3 + docs/reference/lua-initialization.md | 7 +- engine/CMakeLists.txt | 13 +- engine/conanfile.py | 5 + engine/lua/cloe-engine/init.lua | 1 + engine/src/lua_debugger.cpp | 34 + engine/src/lua_setup.cpp | 3 +- engine/src/lua_setup.hpp | 10 + engine/src/main.cpp | 5 + engine/src/main_commands.hpp | 3 + engine/src/main_run.cpp | 10 + engine/src/main_version.cpp | 1 + engine/vendor/lrdb/CMakeLists.txt | 19 + engine/vendor/lrdb/LICENSE_1_0.txt | 23 + engine/vendor/lrdb/NOTICE | 14 + .../vendor/lrdb/include/lrdb/basic_server.hpp | 412 ++++++ engine/vendor/lrdb/include/lrdb/client.hpp | 5 + .../include/lrdb/command_stream/socket.hpp | 129 ++ .../include/lrdb/command_stream/stdstream.hpp | 107 ++ engine/vendor/lrdb/include/lrdb/debugger.hpp | 938 ++++++++++++++ engine/vendor/lrdb/include/lrdb/message.hpp | 249 ++++ engine/vendor/lrdb/include/lrdb/optional.hpp | 261 ++++ engine/vendor/lrdb/include/lrdb/server.hpp | 16 + .../vendor/lrdb/third_party/picojson/LICENSE | 25 + .../lrdb/third_party/picojson/picojson.h | 1105 +++++++++++++++++ tests/conanfile_deployment.py | 1 + 27 files changed, 3401 insertions(+), 4 deletions(-) create mode 100644 engine/src/lua_debugger.cpp create mode 100644 engine/vendor/lrdb/CMakeLists.txt create mode 100644 engine/vendor/lrdb/LICENSE_1_0.txt create mode 100644 engine/vendor/lrdb/NOTICE create mode 100644 engine/vendor/lrdb/include/lrdb/basic_server.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/client.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/command_stream/socket.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/command_stream/stdstream.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/debugger.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/message.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/optional.hpp create mode 100644 engine/vendor/lrdb/include/lrdb/server.hpp create mode 100644 engine/vendor/lrdb/third_party/picojson/LICENSE create mode 100644 engine/vendor/lrdb/third_party/picojson/picojson.h diff --git a/NOTICE.md b/NOTICE.md index 21a0272c0..174d94b63 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -70,6 +70,12 @@ The following third-party libraries are included in the Cloe repository: - Website: https://github.com/kikito/inspect.lua - Source: engine/lua/inspect.lua +- LRDB + - License: BSL-1.0 + - License-Source: https://www.boost.org/LICENSE_1_0.txt + - Website: https://github.com/satoren/LRDB + - Source: engine/vendor/lrdb + - Lust - License: MIT - License-Source: https://raw.githubusercontent.com/bjornbytes/lust/master/LICENSE diff --git a/conanfile.py b/conanfile.py index 0bdc1a390..3cd7cdcd0 100644 --- a/conanfile.py +++ b/conanfile.py @@ -50,6 +50,7 @@ class Cloe(ConanFile): "fPIC": [True, False], "fable_allow_comments": [True, False], "engine_server": [True, False], + "engine_lrdb": [True, False], "with_esmini": [True, False], "with_vtd": [True, False], } @@ -58,6 +59,7 @@ class Cloe(ConanFile): "fPIC": True, "fable_allow_comments": True, "engine_server": True, + "engine_lrdb": True, "with_esmini": True, "with_vtd": False, } @@ -134,6 +136,7 @@ def generate(self): tc.cache_variables["CLOE_VERSION"] = self.version tc.cache_variables["CLOE_VERSION_U32"] = version_u32 tc.cache_variables["CLOE_ENGINE_WITH_SERVER"] = self.options.engine_server + tc.cache_variables["CLOE_ENGINE_WITH_LRDB"] = self.options.engine_lrdb tc.cache_variables["CLOE_WITH_ESMINI"] = self.options.with_esmini tc.cache_variables["CLOE_WITH_VTD"] = self.options.with_vtd tc.generate() diff --git a/docs/reference/lua-initialization.md b/docs/reference/lua-initialization.md index 570f0eb27..326c46d4b 100644 --- a/docs/reference/lua-initialization.md +++ b/docs/reference/lua-initialization.md @@ -12,6 +12,7 @@ by a Lua file: `cloe-engine run simulation.lua` - Lua package path (`--lua-path`, `CLOE_LUA_PATH`) - Disable system packages (`--no-system-lua`) + - Enable LRDB Lua debugger (`--debug-lua`) - Cloe plugins (`--plugin-path`, `CLOE_PLUGIN_PATH`) 2. Initialize Cloe Stack @@ -25,11 +26,13 @@ by a Lua file: `cloe-engine run simulation.lua` - Expose Cloe API via `cloe` Lua table - Load Cloe Lua runtime (located in the package `lib/cloe/lua` directory) -4. Source input files +4. Start LRDB Lua debugger (Optional) + +5. Source input files - Files ending with `.lua` are merged as Lua - Other files are read as JSON -5. Start simulation +6. Start simulation - Schedule triggers pending from the Lua script diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index b33c387b2..305c21b9f 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -123,6 +123,7 @@ set_target_properties(cloe-enginelib PROPERTIES target_compile_definitions(cloe-enginelib PUBLIC SOL_ALL_SAFETIES_ON=1 + LRDB_USE_BOOST_ASIO=1 CLOE_ENGINE_VERSION="${CLOE_ENGINE_VERSION}" CLOE_ENGINE_TIMESTAMP="${CLOE_ENGINE_TIMESTAMP}" PROJECT_SOURCE_DIR=\"${CMAKE_CURRENT_SOURCE_DIR}\" @@ -155,6 +156,16 @@ else() target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_SERVER=0) endif() +option(CLOE_ENGINE_WITH_LRDB "Enable LRDB Lua Debugger?" ON) +if(CLOE_ENGINE_WITH_LRDB) + add_subdirectory(vendor/lrdb) + target_sources(cloe-enginelib PRIVATE src/lua_debugger.cpp) + target_link_libraries(cloe-enginelib PRIVATE lrdb::lrdb) + target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_LRDB=1) +else() + target_compile_definitions(cloe-enginelib PUBLIC CLOE_ENGINE_WITH_LRDB=0) +endif() + if(BUILD_TESTING) message(STATUS "Building test-enginelib executable.") add_executable(test-enginelib @@ -176,7 +187,7 @@ if(BUILD_TESTING) endif() # Executable --------------------------------------------------------- -message(STATUS "Building cloe-engine executable [with server=${CLOE_ENGINE_WITH_SERVER}].") +message(STATUS "Building cloe-engine executable [with server=${CLOE_ENGINE_WITH_SERVER}, lrdb=${CLOE_ENGINE_WITH_LRDB}].") add_subdirectory(vendor/linenoise) add_executable(cloe-engine src/main.cpp diff --git a/engine/conanfile.py b/engine/conanfile.py index 4b1e468b5..5ab39b51c 100644 --- a/engine/conanfile.py +++ b/engine/conanfile.py @@ -21,12 +21,16 @@ class CloeEngine(ConanFile): # server dependencies are incompatible with your target system. "server": [True, False], + # Whether the LRDB integration is compiled and built into the Cloe engine. + "lrdb": [True, False], + # Make the compiler be strict and pedantic. # Disable if you upgrade compilers and run into new warnings preventing # the build from completing. May be removed in the future. "pedantic": [True, False], } default_options = { + "lrdb": True, "server": True, "pedantic": True, @@ -72,6 +76,7 @@ def generate(self): tc.cache_variables["CMAKE_EXPORT_COMPILE_COMMANDS"] = True tc.cache_variables["CLOE_PROJECT_VERSION"] = self.version tc.cache_variables["CLOE_ENGINE_WITH_SERVER"] = self.options.server + tc.cache_variables["CLOE_ENGINE_WITH_LRDB"] = self.options.lrdb tc.cache_variables["TargetLintingExtended"] = self.options.pedantic tc.generate() diff --git a/engine/lua/cloe-engine/init.lua b/engine/lua/cloe-engine/init.lua index 4ca995df6..a53ce533d 100644 --- a/engine/lua/cloe-engine/init.lua +++ b/engine/lua/cloe-engine/init.lua @@ -56,6 +56,7 @@ local engine = { ["cloe-stackfile-4.1"] = true, ["cloe-server"] = false, + ["cloe-lrdb"] = false, }, --- @type table Lua table dumped as JSON report at end of simulation. diff --git a/engine/src/lua_debugger.cpp b/engine/src/lua_debugger.cpp new file mode 100644 index 000000000..8f68f8fe2 --- /dev/null +++ b/engine/src/lua_debugger.cpp @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file stack_lua.cpp + */ + +#include "lua_setup.hpp" + +#include // lrdb::server +#include // for state_view + +namespace cloe { + +void start_lua_debugger(sol::state& lua, int listen_port) { + static lrdb::server debug_server(listen_port); + debug_server.reset(lua.lua_state()); +} + +} // namespace cloe diff --git a/engine/src/lua_setup.cpp b/engine/src/lua_setup.cpp index b6b66d48e..b6c21a434 100644 --- a/engine/src/lua_setup.cpp +++ b/engine/src/lua_setup.cpp @@ -198,7 +198,8 @@ void register_cloe_engine(sol::state_view& lua, Stack& stack) { "cloe-stackfile-4.1", true, // Server enabled: - "cloe-server", CLOE_ENGINE_WITH_SERVER != 0 + "cloe-server", CLOE_ENGINE_WITH_SERVER != 0, + "cloe-lrdb", CLOE_ENGINE_WITH_LRDB != 0 ); // clang-format on diff --git a/engine/src/lua_setup.hpp b/engine/src/lua_setup.hpp index 0aa84914d..00e473675 100644 --- a/engine/src/lua_setup.hpp +++ b/engine/src/lua_setup.hpp @@ -58,6 +58,16 @@ struct LuaOptions { */ sol::state new_lua(const LuaOptions& opt, Stack& s); +#if CLOE_ENGINE_WITH_LRDB +/** + * Start Lua debugger server on port. + * + * \param lua + * \param listen_port + */ +void start_lua_debugger(sol::state& lua, int listen_port); +#endif + /** * Merge the provided Lua file into the existing `Stack`, respecting `StackOptions`. * diff --git a/engine/src/main.cpp b/engine/src/main.cpp index 5dda6490b..9db2b24ed 100644 --- a/engine/src/main.cpp +++ b/engine/src/main.cpp @@ -79,6 +79,11 @@ int main(int argc, char** argv) { run->add_flag("--require-success,!--no-require-success", run_options.require_success, "Require simulation success") ->envname("CLOE_REQUIRE_SUCCESS"); + run->add_flag("--debug-lua", run_options.debug_lua, + "Debug the Lua simulation"); + run->add_option("--debug-lua-port", run_options.debug_lua_port, + "Port to listen on for debugger to attach to") + ->envname("CLOE_DEBUG_LUA_PORT"); run->add_option("files", run_files, "Files to merge into a single stackfile")->required(); // One of the above subcommands must be used. diff --git a/engine/src/main_commands.hpp b/engine/src/main_commands.hpp index fcdb60a9c..e701055b4 100644 --- a/engine/src/main_commands.hpp +++ b/engine/src/main_commands.hpp @@ -75,6 +75,9 @@ struct RunOptions { bool write_output = true; bool require_success = false; bool report_progress = true; + + bool debug_lua = false; + int debug_lua_port = 21110; }; int run(const RunOptions& opt, const std::vector& filepaths); diff --git a/engine/src/main_run.cpp b/engine/src/main_run.cpp index 6773624e7..ef9827997 100644 --- a/engine/src/main_run.cpp +++ b/engine/src/main_run.cpp @@ -64,6 +64,16 @@ int run(const RunOptions& opt, const std::vector& filepaths) { // Load the stack file: cloe::Stack stack = cloe::new_stack(opt.stack_options); sol::state lua = cloe::new_lua(opt.lua_options, stack); +#if CLOE_ENGINE_WITH_LRDB + if (opt.debug_lua) { + log->info("Lua debugger listening at port: {}", opt.debug_lua_port); + cloe::start_lua_debugger(lua, opt.debug_lua_port); + } +#else + if (opt.debug_lua) { + log->error("Lua debugger feature not available."); + } +#endif try { cloe::conclude_error(*opt.stack_options.error, [&]() { for (const auto& file : filepaths) { diff --git a/engine/src/main_version.cpp b/engine/src/main_version.cpp index 833c83dca..87e9db335 100644 --- a/engine/src/main_version.cpp +++ b/engine/src/main_version.cpp @@ -45,6 +45,7 @@ int version(const VersionOptions& opt) { {"stack", CLOE_STACK_VERSION}, // from "stack.hpp" {"plugin_manifest", CLOE_PLUGIN_MANIFEST_VERSION}, // from {"feature_server", CLOE_ENGINE_WITH_SERVER != 0}, // from CMakeLists.txt + {"feature_lrdb", CLOE_ENGINE_WITH_LRDB != 0}, // from CMakeLists.txt }; if (opt.output_json) { diff --git a/engine/vendor/lrdb/CMakeLists.txt b/engine/vendor/lrdb/CMakeLists.txt new file mode 100644 index 000000000..f0e9fe339 --- /dev/null +++ b/engine/vendor/lrdb/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.15 FATAL_ERROR) + +project(LRDB LANGUAGES CXX) + +add_library(lrdb INTERFACE) +add_library(lrdb::lrdb ALIAS lrdb) +target_include_directories(lrdb + INTERFACE + "$" + "$" +) +target_link_libraries(lrdb + INTERFACE + lua::lua +) +set_target_properties(lrdb PROPERTIES + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON +) diff --git a/engine/vendor/lrdb/LICENSE_1_0.txt b/engine/vendor/lrdb/LICENSE_1_0.txt new file mode 100644 index 000000000..36b7cd93c --- /dev/null +++ b/engine/vendor/lrdb/LICENSE_1_0.txt @@ -0,0 +1,23 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/engine/vendor/lrdb/NOTICE b/engine/vendor/lrdb/NOTICE new file mode 100644 index 000000000..f0a6751b5 --- /dev/null +++ b/engine/vendor/lrdb/NOTICE @@ -0,0 +1,14 @@ +LRDB Modifications +================== + +The LRDB library is sourced from the GitHub repository below: + + - License: BSL-1.0 + - License-Source: https://www.boost.org/LICENSE_1_0.txt + - Website: https://github.com/satoren/LRDB + +The source code has been modified in following ways: + +- Remove files not relevant to our use (e.g. test, node, cmake). +- Replace include/lrdb/debugger.hpp implementation of is_file_path_match. +- Replace CMakeLists.txt with a simplified version. diff --git a/engine/vendor/lrdb/include/lrdb/basic_server.hpp b/engine/vendor/lrdb/include/lrdb/basic_server.hpp new file mode 100644 index 000000000..8d338ab01 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/basic_server.hpp @@ -0,0 +1,412 @@ +#pragma once + +#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800) +#include +#include +#include + +#include "debugger.hpp" +#include "message.hpp" + +namespace lrdb { + +#define LRDB_SERVER_PROTOCOL_VERSION "2" + +/// @brief Debug Server Class +/// template type is messaging communication customization point +/// require members +/// void close(); /// connection close +/// bool is_open() const; /// connection is opened +/// void poll(); /// polling event data. Require non blocking +/// void run_one(); /// run event data. Blocking until run one +/// message. +/// void wait_for_connection(); //Blocking until connection. +/// bool send_message(const std::string& message); /// send message to +/// communication opponent +/// //callback functions. Must that call inside poll or run_one +/// std::function on_data;///callback for +/// receiving data. +/// std::function on_connection; +/// std::function on_close; +/// std::function on_error; +template +class basic_server { + public: + /// @brief constructor + /// @param arg Forward to StreamType constructor + template + basic_server(StreamArgs&&... arg) + : wait_for_connect_(true), + command_stream_(std::forward(arg)...) { + init(); + } + + ~basic_server() { exit(); } + + /// @brief attach (or detach) for debug target + /// @param lua_State* debug target + void reset(lua_State* L = 0) { + debugger_.reset(L); + if (!L) { + exit(); + } + } + + /// @brief Exit debug server + void exit() { + send_notify(notify_message("exit")); + command_stream_.close(); + } + + StreamType& command_stream() { return command_stream_; }; + + private: + void init() { + debugger_.set_pause_handler([&](debugger&) { + send_pause_status(); + while (debugger_.paused() && command_stream_.is_open()) { + command_stream_.run_one(); + } + send_notify(notify_message("running")); + }); + + debugger_.set_tick_handler([&](debugger&) { + if (wait_for_connect_) { + command_stream_.wait_for_connection(); + } + command_stream_.poll(); + }); + + command_stream_.on_connection = [=]() { connected_done(); }; + command_stream_.on_data = [=](const std::string& data) { + execute_message(data); + }; + command_stream_.on_close = [=]() { debugger_.unpause(); }; + } + void send_pause_status() { + json::object pauseparam; + pauseparam["reason"] = json::value(debugger_.pause_reason()); + send_notify(notify_message("paused", json::value(pauseparam))); + } + void connected_done() { + wait_for_connect_ = false; + json::object param; + param["protocol_version"] = json::value(LRDB_SERVER_PROTOCOL_VERSION); + + json::object lua; + lua["version"] = json::value(LUA_VERSION); + lua["release"] = json::value(LUA_RELEASE); + lua["copyright"] = json::value(LUA_COPYRIGHT); + + param["lua"] = json::value(lua); + send_notify(notify_message("connected", json::value(param))); + } + + bool send_message(const std::string& message) { + return command_stream_.send_message(message); + } + void execute_message(const std::string& message) { + json::value msg; + std::string err = json::parse(msg, message); + if (err.empty()) { + if (message::is_request(msg)) { + request_message request; + message::parse(msg, request); + execute_request(request); + } + } + } + + bool send_notify(const notify_message& message) { + return send_message(message::serialize(message)); + } + bool send_response(response_message& message) { + return send_message(message::serialize(message)); + } + bool step_request(response_message& response, const json::value&) { + debugger_.step(); + return send_response(response); + } + + bool step_in_request(response_message& response, const json::value&) { + debugger_.step_in(); + return send_response(response); + } + bool step_out_request(response_message& response, const json::value&) { + debugger_.step_out(); + return send_response(response); + } + bool continue_request(response_message& response, const json::value&) { + debugger_.unpause(); + return send_response(response); + } + bool pause_request(response_message& response, const json::value&) { + debugger_.pause(); + return send_response(response); + } + bool add_breakpoint_request(response_message& response, + const json::value& param) { + bool has_source = param.get("file").is(); + bool has_condition = param.get("condition").is(); + bool has_hit_condition = param.get("hit_condition").is(); + bool has_line = param.get("line").is(); + if (has_source && has_line) { + std::string source = + param.get().at("file").get(); + int line = + static_cast(param.get().at("line").get()); + + std::string condition; + std::string hit_condition; + if (has_condition) { + condition = + param.get().at("condition").get(); + } + if (has_hit_condition) { + hit_condition = + param.get().at("hit_condition").get(); + } + debugger_.add_breakpoint(source, line, condition, hit_condition); + + } else { + response.error = + response_error(response_error::InvalidParams, "invalid params"); + } + return send_response(response); + } + + bool clear_breakpoints_request(response_message& response, + const json::value& param) { + bool has_source = param.get("file").is(); + bool has_line = param.get("line").is(); + if (!has_source) { + debugger_.clear_breakpoints(); + } else { + std::string source = + param.get().at("file").get(); + if (!has_line) { + debugger_.clear_breakpoints(source); + } else { + int line = static_cast( + param.get().at("line").get()); + debugger_.clear_breakpoints(source, line); + } + } + + return send_response(response); + } + + bool get_breakpoints_request(response_message& response, const json::value&) { + const debugger::line_breakpoint_type& breakpoints = + debugger_.line_breakpoints(); + + json::array res; + for (const auto& b : breakpoints) { + json::object br; + br["file"] = json::value(b.file); + if (!b.func.empty()) { + br["func"] = json::value(b.func); + } + br["line"] = json::value(double(b.line)); + if (!b.condition.empty()) { + br["condition"] = json::value(b.condition); + } + br["hit_count"] = json::value(double(b.hit_count)); + res.push_back(json::value(br)); + } + + response.result = json::value(res); + + return send_response(response); + } + + bool get_stacktrace_request(response_message& response, const json::value&) { + auto callstack = debugger_.get_call_stack(); + json::array res; + for (auto& s : callstack) { + json::object data; + if (s.source()) { + data["file"] = json::value(s.source()); + } + const char* name = s.name(); + if (!name || name[0] == '\0') { + name = s.name(); + } + if (!name || name[0] == '\0') { + name = s.namewhat(); + } + if (!name || name[0] == '\0') { + name = s.what(); + } + if (!name || name[0] == '\0') { + name = s.source(); + } + data["func"] = json::value(name); + data["line"] = json::value(double(s.currentline())); + data["id"] = json::value(s.short_src()); + res.push_back(json::value(data)); + } + response.result = json::value(res); + + return send_response(response); + } + + bool get_local_variable_request(response_message& response, + const json::value& param) { + if (!param.is()) { + response.error = + response_error(response_error::InvalidParams, "invalid params"); + + return send_response(response); + } + bool has_stackno = param.get("stack_no").is(); + int depth = param.get("depth").is() + ? static_cast(param.get("depth").get()) + : 1; + if (has_stackno) { + int stack_no = static_cast(param.get("stack_no").get()); + auto callstack = debugger_.get_call_stack(); + if (int(callstack.size()) > stack_no) { + auto localvar = callstack[stack_no].get_local_vars(depth); + json::object obj; + for (auto& var : localvar) { + obj[var.first] = var.second; + } + response.result = json::value(obj); + return send_response(response); + } + } + response.error = + response_error(response_error::InvalidParams, "invalid params"); + + return send_response(response); + } + + bool get_upvalues_request(response_message& response, + const json::value& param) { + if (!param.is()) { + response.error = + response_error(response_error::InvalidParams, "invalid params"); + + return send_response(response); + } + bool has_stackno = param.get("stack_no").is(); + int depth = param.get("depth").is() + ? static_cast(param.get("depth").get()) + : 1; + if (has_stackno) { + int stack_no = static_cast( + param.get().at("stack_no").get()); + auto callstack = debugger_.get_call_stack(); + if (int(callstack.size()) > stack_no) { + auto localvar = callstack[stack_no].get_upvalues(depth); + json::object obj; + for (auto& var : localvar) { + obj[var.first] = var.second; + } + + response.result = json::value(obj); + + return send_response(response); + } + } + response.error = + response_error(response_error::InvalidParams, "invalid params"); + + return send_response(response); + } + bool eval_request(response_message& response, const json::value& param) { + bool has_chunk = param.get("chunk").is(); + bool has_stackno = param.get("stack_no").is(); + + bool use_global = + !param.get("global").is() || param.get("global").get(); + bool use_upvalue = + !param.get("upvalue").is() || param.get("upvalue").get(); + bool use_local = + !param.get("local").is() || param.get("local").get(); + + int depth = param.get("depth").is() + ? static_cast(param.get("depth").get()) + : 1; + + if (has_chunk && has_stackno) { + std::string chunk = + param.get().at("chunk").get(); + int stack_no = static_cast( + param.get().at("stack_no").get()); + auto callstack = debugger_.get_call_stack(); + if (int(callstack.size()) > stack_no) { + std::string error; + json::value ret = json::value( + callstack[stack_no].eval(chunk.c_str(), error, use_global, + use_upvalue, use_local, depth + 1)); + if (error.empty()) { + response.result = ret; + + return send_response(response); + } else { + response.error = response_error(response_error::InvalidParams, error); + + return send_response(response); + } + } + } + response.error = + response_error(response_error::InvalidParams, "invalid params"); + + return send_response(response); + } + bool get_global_request(response_message& response, + const json::value& param) { + int depth = param.get("depth").is() + ? static_cast(param.get("depth").get()) + : 1; + response.result = + debugger_.get_global_table(depth + 1); //+ 1 is global table self + + return send_response(response); + } + + void execute_request(const request_message& req) { + typedef bool (basic_server::*exec_cmd_fn)(response_message & response, + const json::value& param); + + static const std::map cmd_map = { +#define LRDB_DEBUG_COMMAND_TABLE(NAME) {#NAME, &basic_server::NAME##_request} + LRDB_DEBUG_COMMAND_TABLE(step), + LRDB_DEBUG_COMMAND_TABLE(step_in), + LRDB_DEBUG_COMMAND_TABLE(step_out), + LRDB_DEBUG_COMMAND_TABLE(continue), + LRDB_DEBUG_COMMAND_TABLE(pause), + LRDB_DEBUG_COMMAND_TABLE(add_breakpoint), + LRDB_DEBUG_COMMAND_TABLE(get_breakpoints), + LRDB_DEBUG_COMMAND_TABLE(clear_breakpoints), + LRDB_DEBUG_COMMAND_TABLE(get_stacktrace), + LRDB_DEBUG_COMMAND_TABLE(get_local_variable), + LRDB_DEBUG_COMMAND_TABLE(get_upvalues), + LRDB_DEBUG_COMMAND_TABLE(eval), + LRDB_DEBUG_COMMAND_TABLE(get_global), +#undef LRDB_DEBUG_COMMAND_TABLE + }; + + response_message response; + response.id = req.id; + auto match = cmd_map.find(req.method); + if (match != cmd_map.end()) { + (this->*(match->second))(response, req.params); + } else { + response.error = response_error(response_error::MethodNotFound, + "method not found : " + req.method); + send_response(response); + } + } + bool wait_for_connect_; + debugger debugger_; + StreamType command_stream_; +}; +} // namespace lrdb + +#else +#error Needs at least a C++11 compiler +#endif \ No newline at end of file diff --git a/engine/vendor/lrdb/include/lrdb/client.hpp b/engine/vendor/lrdb/include/lrdb/client.hpp new file mode 100644 index 000000000..87317fd22 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/client.hpp @@ -0,0 +1,5 @@ +#pragma once +#include +#include + +// Not implemented now diff --git a/engine/vendor/lrdb/include/lrdb/command_stream/socket.hpp b/engine/vendor/lrdb/include/lrdb/command_stream/socket.hpp new file mode 100644 index 000000000..a501af924 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/command_stream/socket.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include +#include + +#include +#if __cplusplus >= 201103L || defined(_MSC_VER) && _MSC_VER >= 1800 +#else +#define ASIO_HAS_BOOST_DATE_TIME +#define LRDB_USE_BOOST_ASIO +#endif + +#ifdef LRDB_USE_BOOST_ASIO +#include +#else +#define ASIO_STANDALONE +#include +#endif + +namespace lrdb { +#ifdef LRDB_USE_BOOST_ASIO +namespace asio { +using boost::system::error_code; +using namespace boost::asio; +} +#else +#endif + +// one to one server socket +class command_stream_socket { + public: + command_stream_socket(uint16_t port = 21110) + : endpoint_(asio::ip::tcp::v4(), port), + acceptor_(io_service_, endpoint_), + socket_(io_service_) { + async_accept(); + } + + ~command_stream_socket() { + close(); + acceptor_.close(); + } + + void close() { + socket_.close(); + if (on_close) { + on_close(); + } + } + void reconnect() { + close(); + async_accept(); + } + + std::function on_data; + std::function on_connection; + std::function on_close; + std::function on_error; + + bool is_open() const { return socket_.is_open(); } + void poll() { io_service_.poll(); } + void run_one() { io_service_.run_one(); } + void wait_for_connection() { + while (!is_open()) { + io_service_.run_one(); + } + } + + // sync + bool send_message(const std::string& message) { + asio::error_code ec; + std::string data = message + "\r\n"; + asio::write(socket_, asio::buffer(data), ec); + if (ec) { + if (on_error) { + on_error(ec.message()); + } + reconnect(); + return false; + } + return true; + } + + private: + void async_accept() { + acceptor_.async_accept(socket_, [&](const asio::error_code& ec) { + if (!ec) { + connected_done(); + } else { + if (on_error) { + on_error(ec.message()); + } + reconnect(); + } + }); + } + void connected_done() { + if (on_connection) { + on_connection(); + } + start_receive_commands(); + } + void start_receive_commands() { + asio::async_read_until(socket_, read_buffer_, "\n", + [&](const asio::error_code& ec, std::size_t) { + if (!ec) { + std::istream is(&read_buffer_); + std::string command; + std::getline(is, command); + if (on_data) { + on_data(command); + } + start_receive_commands(); + } else { + if (on_error) { + on_error(ec.message()); + } + reconnect(); + } + }); + } + + asio::io_service io_service_; + asio::ip::tcp::endpoint endpoint_; + asio::ip::tcp::acceptor acceptor_; + asio::ip::tcp::socket socket_; + asio::streambuf read_buffer_; +}; +} diff --git a/engine/vendor/lrdb/include/lrdb/command_stream/stdstream.hpp b/engine/vendor/lrdb/include/lrdb/command_stream/stdstream.hpp new file mode 100644 index 000000000..91268ee42 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/command_stream/stdstream.hpp @@ -0,0 +1,107 @@ +#include +#include +#include +#include +#include + +// experimental implementation + +#define LRDB_IOSTREAM_PREFIX "lrdb_stream_message:" +namespace lrdb { + +class command_stream_stdstream { + public: + command_stream_stdstream(std::istream& in, std::ostream& out) + : end_(false), istream_(in), ostream_(out) { + thread_ = std::thread([&] { read_thread(); }); + } + ~command_stream_stdstream() { close(); } + + void close() { + { + std::unique_lock lk(mutex_); + end_ = true; + cond_.notify_all(); + } + ostream_ << 1; + if (thread_.joinable()) { + thread_.join(); + } + } + std::function on_data; + std::function on_connection; + std::function on_close; + std::function on_error; + + bool is_open() const { return true; } + void poll() { + std::string msg = pop_message(); + if (!msg.empty()) { + on_data(msg); + } + } + void run_one() { + std::string msg = wait_message(); + if (!msg.empty()) { + on_data(msg); + } + } + void wait_for_connection() {} + + // sync + bool send_message(const std::string& message) { + ostream_ << (LRDB_IOSTREAM_PREFIX + message + "\r\n"); + return true; + } + + private: + std::string pop_message() { + std::unique_lock lk(mutex_); + if (command_buffer_.empty()) { + return ""; + } + + std::string message = std::move(command_buffer_.front()); + command_buffer_.pop_front(); + return message; + } + std::string wait_message() { + std::unique_lock lk(mutex_); + while (command_buffer_.empty() && !end_) { + cond_.wait(lk); + } + if (command_buffer_.empty()) { + return ""; + } + std::string message = std::move(command_buffer_.front()); + command_buffer_.pop_front(); + return message; + } + void push_message(std::string message) { + std::unique_lock lk(mutex_); + command_buffer_.push_back(std::move(message)); + cond_.notify_one(); + } + + void read_thread() { + std::unique_lock lk(mutex_); + std::string msg; + while (!end_) { + mutex_.unlock(); + std::getline(istream_, msg); + if (msg.find(LRDB_IOSTREAM_PREFIX) == 0) { + push_message(msg.substr(sizeof(LRDB_IOSTREAM_PREFIX))); + } + mutex_.lock(); + } + } + + bool end_; + std::istream& istream_; + std::ostream& ostream_; + std::deque command_buffer_; + std::mutex mutex_; + std::condition_variable cond_; + std::thread thread_; +}; +} diff --git a/engine/vendor/lrdb/include/lrdb/debugger.hpp b/engine/vendor/lrdb/include/lrdb/debugger.hpp new file mode 100644 index 000000000..28b9009b2 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/debugger.hpp @@ -0,0 +1,938 @@ +#pragma once + +#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800) + +#include +#include +#include +#include +#include + +#include + +#include "picojson.h" +extern "C" { +#include +#include +#include +} + +namespace lrdb { +namespace json { +using namespace ::picojson; +} +#if LUA_VERSION_NUM < 502 +inline int lua_absindex(lua_State* L, int idx) { + return (idx > 0 || (idx <= LUA_REGISTRYINDEX)) ? idx : lua_gettop(L) + 1 + idx; +} +inline size_t lua_rawlen(lua_State* L, int index) { return lua_objlen(L, index); } +inline void lua_pushglobaltable(lua_State* L) { lua_pushvalue(L, LUA_GLOBALSINDEX); } +inline void lua_rawgetp(lua_State* L, int index, void* p) { + lua_pushlightuserdata(L, p); + lua_rawget(L, LUA_REGISTRYINDEX); +} +#endif +namespace utility { + +/// @brief Lua stack value convert to json +inline json::value to_json(lua_State* L, int index, int max_recursive = 1) { + index = lua_absindex(L, index); + int type = lua_type(L, index); + switch (type) { + case LUA_TNIL: + return json::value(); + case LUA_TBOOLEAN: + return json::value(bool(lua_toboolean(L, index) != 0)); + case LUA_TNUMBER: + // todo integer or double + { + double n = lua_tonumber(L, index); + if (std::isnan(n)) { + return json::value("NaN"); + } + if (std::isinf(n)) { + return json::value("Infinity"); + } else { + return json::value(n); + } + } + case LUA_TSTRING: + return json::value(lua_tostring(L, index)); + case LUA_TTABLE: { + if (max_recursive <= 0) { + char buffer[128] = {}; +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + sprintf(buffer, "%p", lua_topointer(L, -1)); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + json::object obj; + + int tt = luaL_getmetafield(L, index, "__name"); + const char* type = (tt == LUA_TSTRING) ? lua_tostring(L, -1) : luaL_typename(L, index); + obj[type] = json::value(buffer); + if (tt != LUA_TNIL) { + lua_pop(L, 1); /* remove '__name' */ + } + return json::value(obj); + } + int array_size = lua_rawlen(L, index); + if (array_size > 0) { + json::array a; + lua_pushnil(L); + while (lua_next(L, index) != 0) { + if (lua_type(L, -2) == LUA_TNUMBER) { + a.push_back(to_json(L, -1, max_recursive - 1)); + } + lua_pop(L, 1); // pop value + } + return json::value(a); + } else { + json::object obj; + lua_pushnil(L); + while (lua_next(L, index) != 0) { + if (lua_type(L, -2) == LUA_TSTRING) { + const char* key = lua_tostring(L, -2); + json::value& b = obj[key]; + + b = to_json(L, -1, max_recursive - 1); + } + lua_pop(L, 1); // pop value + } + return json::value(obj); + } + } + case LUA_TUSERDATA: { + if (luaL_callmeta(L, index, "__tostring")) { + json::value v = to_json(L, -1, max_recursive); // return value to json + lua_pop(L, 1); // pop return value and metatable + return v; + } + if (luaL_callmeta(L, index, "__totable")) { + json::value v = to_json(L, -1, max_recursive); // return value to json + lua_pop(L, 1); // pop return value and metatable + return v; + } + } + case LUA_TLIGHTUSERDATA: + case LUA_TTHREAD: + case LUA_TFUNCTION: { + int tt = luaL_getmetafield(L, index, "__name"); + const char* type = (tt == LUA_TSTRING) ? lua_tostring(L, -1) : luaL_typename(L, index); + char buffer[128] = {}; +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + sprintf(buffer, "%s: %p", type, lua_topointer(L, index)); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + if (tt != LUA_TNIL) { + lua_pop(L, 1); /* remove '__name' */ + } + return json::value(buffer); + } + } + return json::value(); +} +/// @brief push value to Lua stack from json +inline void push_json(lua_State* L, const json::value& v) { + if (v.is()) { + lua_pushnil(L); + } else if (v.is()) { + lua_pushboolean(L, v.get()); + } else if (v.is()) { + lua_pushnumber(L, v.get()); + } else if (v.is()) { + const std::string& str = v.get(); + lua_pushlstring(L, str.c_str(), str.size()); + } else if (v.is()) { + const json::object& obj = v.get(); + lua_createtable(L, 0, obj.size()); + for (json::object::const_iterator itr = obj.begin(); itr != obj.end(); ++itr) { + push_json(L, itr->second); + lua_setfield(L, -2, itr->first.c_str()); + } + } else if (v.is()) { + const json::array& array = v.get(); + lua_createtable(L, array.size(), 0); + for (size_t index = 0; index < array.size(); ++index) { + push_json(L, array[index]); + lua_rawseti(L, -2, index + 1); + } + } +} +} // namespace utility + +/// @brief line based break point type +struct breakpoint_info { + breakpoint_info() : line(-1), hit_count(0) {} + std::string file; /// source file + std::string func; /// function name(currently unused) + int line; /// line number + std::string condition; /// break point condition + std::string hit_condition; // expression that controls how many hits of the + // breakpoint are ignored + size_t hit_count; /// breakpoint hit counts +}; + +/// @brief debug data +/// this data is available per stack frame +class debug_info { + public: + typedef std::vector > local_vars_type; + debug_info() : state_(0), debug_(0) {} + debug_info(const debug_info& other) + : state_(other.state_), debug_(other.debug_), got_debug_(other.got_debug_) {} + debug_info& operator=(const debug_info& other) { + state_ = other.state_; + debug_ = other.debug_; + got_debug_ = other.got_debug_; + return *this; + } + void assign(lua_State* L, lua_Debug* debug, const char* got_type = 0) { + state_ = L; + debug_ = debug; + got_debug_.clear(); + if (got_type) { + got_debug_.append(got_type); + } + if (debug->event == LUA_HOOKLINE) { + got_debug_.append("l"); + } + } + bool is_available_info(const char* type) const { + return got_debug_.find(type) != std::string::npos; + } + bool get_info(const char* type) { + if (!is_available()) { + return 0; + } + if (is_available_info(type)) { + return true; + } + return lua_getinfo(state_, type, debug_) != 0; + } + /// @breaf get name + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + const char* name() { + if (!get_info("n") || !debug_->name) { + return ""; + } + return debug_->name; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + const char* namewhat() { + if (!get_info("n") || !debug_->namewhat) { + return ""; + } + return debug_->namewhat; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + const char* what() { + if (!get_info("S") || !debug_->what) { + return ""; + } + return debug_->what; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + const char* source() { + if (!get_info("S") || !debug_->source) { + return ""; + } + return debug_->source; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + int currentline() { + if (!get_info("l")) { + return -1; + } + return debug_->currentline; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + int linedefined() { + if (!get_info("S")) { + return -1; + } + return debug_->linedefined; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + int lastlinedefined() { + if (!get_info("S")) { + return -1; + } + return debug_->lastlinedefined; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + int number_of_upvalues() { + if (!get_info("u")) { + return -1; + } + return debug_->nups; + } +#if LUA_VERSION_NUM >= 502 + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + int number_of_parameters() { + if (!get_info("u")) { + return -1; + } + return debug_->nparams; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + bool is_variadic_arg() { + if (!get_info("u")) { + return false; + } + return debug_->isvararg != 0; + } + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + bool is_tailcall() { + if (!get_info("t")) { + return false; + } + return debug_->istailcall != 0; + } +#endif + /// @link https://www.lua.org/manual/5.3/manual.html#4.9 + const char* short_src() { + if (!get_info("S")) { + return ""; + } + return debug_->short_src; + } + /* + std::vector valid_lines_on_function() { + std::vector ret; + lua_getinfo(state_, "fL", debug_); + lua_pushnil(state_); + while (lua_next(state_, -2) != 0) { + int t = lua_type(state_, -1); + ret.push_back(lua_toboolean(state_, -1)); + lua_pop(state_, 1); // pop value + } + lua_pop(state_, 2); + return ret; + }*/ + + /// @brief evaluate script + /// e.g. + /// auto ret = debuginfo.eval("return 4,6"); + /// for(auto r: ret){std::cout << r.get() << ,;}//output "4,6," + /// @param script luascript string + /// @param global execute environment include global + /// @param upvalue execute environment include upvalues + /// @param local execute environment include local variables + /// @param object_depth depth of extract for table for return value + /// @return array of name and value pair + std::vector eval(const char* script, bool global = true, bool upvalue = true, + bool local = true, int object_depth = 1) { + std::string error; + std::vector ret = eval(script, error, global, upvalue, local, object_depth); + if (!error.empty()) { + ret.push_back(json::value(error)); + } + return ret; + } + + std::vector eval(const char* script, std::string& error, bool global = true, + bool upvalue = true, bool local = true, int object_depth = 1) { + int stack_start = lua_gettop(state_); + int loadstat = luaL_loadstring(state_, (std::string("return ") + script).c_str()); + if (loadstat != 0) { + lua_pop(state_, 1); + loadstat = luaL_loadstring(state_, script); + } + if (!lua_isfunction(state_, -1)) { + error = lua_tostring(state_, -1); + return std::vector(); + } + + create_eval_env(global, upvalue, local); +#if LUA_VERSION_NUM >= 502 + lua_setupvalue(state_, -2, 1); +#else + lua_setfenv(state_, -2); +#endif + int call_stat = lua_pcall(state_, 0, LUA_MULTRET, 0); + if (call_stat != 0) { + error = lua_tostring(state_, -1); + return std::vector(); + } + std::vector ret; + int ret_end = lua_gettop(state_); + for (int retindex = stack_start + 1; retindex <= ret_end; ++retindex) { + ret.push_back(utility::to_json(state_, retindex, object_depth)); + } + lua_settop(state_, stack_start); + return ret; + } + /// @brief get local variables + /// @param object_depth depth of extract for table for return value + /// @return array of name and value pair + local_vars_type get_local_vars(int object_depth = 0) { + local_vars_type localvars; + int varno = 1; + while (const char* varname = lua_getlocal(state_, debug_, varno++)) { + if (varname[0] != '(') { + localvars.push_back(std::pair( + varname, utility::to_json(state_, -1, object_depth))); + } + lua_pop(state_, 1); + } +#if LUA_VERSION_NUM >= 502 + if (is_variadic_arg()) { + json::array va; + int varno = -1; + while (const char* varname = lua_getlocal(state_, debug_, varno--)) { + (void)varname; // unused + va.push_back(utility::to_json(state_, -1)); + lua_pop(state_, 1); + } + localvars.push_back(std::pair("(*vararg)", json::value(va))); + } +#endif + return localvars; + } + /// @brief set local variables + /// @param name local variable name + /// @param v assign value + /// @return If set is successful, return true. Otherwise return false. + bool set_local_var(const char* name, const json::value& v) { + local_vars_type vars = get_local_vars(); + for (size_t index = 0; index < vars.size(); ++index) { + if (vars[index].first == name) { + return set_local_var(index, v); + } + } + return false; // local variable name not found + } + + /// @brief set local variables + /// @param local_var_index local variable index + /// @param v assign value + /// @return If set is successful, return true. Otherwise return false. + bool set_local_var(int local_var_index, const json::value& v) { + utility::push_json(state_, v); + return lua_setlocal(state_, debug_, local_var_index + 1) != 0; + } + + /// @brief get upvalues + /// @param object_depth depth of extract for table for return value + /// @return array of name and value pair + local_vars_type get_upvalues(int object_depth = 0) { + local_vars_type localvars; + + lua_getinfo(state_, "f", debug_); // push current running function + int upvno = 1; + while (const char* varname = lua_getupvalue(state_, -1, upvno++)) { + localvars.push_back( + std::pair(varname, utility::to_json(state_, -1, object_depth))); + lua_pop(state_, 1); + } + lua_pop(state_, 1); // pop current running function + return localvars; + } + /// @brief set upvalue + /// @param name upvalue name + /// @param v assign value + /// @return If set is successful, return true. Otherwise return false. + bool set_upvalue(const char* name, const json::value& v) { + local_vars_type vars = get_upvalues(); + for (size_t index = 0; index < vars.size(); ++index) { + if (vars[index].first == name) { + return set_upvalue(index, v); + } + } + return false; // local variable name not found + } + + /// @brief set upvalue + /// @param var_index upvalue index + /// @param v assign value + /// @return If set is successful, return true. Otherwise return false. + bool set_upvalue(int var_index, const json::value& v) { + lua_getinfo(state_, "f", debug_); // push current running function + int target_functin_index = lua_gettop(state_); + utility::push_json(state_, v); + bool ret = lua_setupvalue(state_, target_functin_index, var_index + 1) != 0; + lua_pop(state_, 1); // pop current running function + return ret; + } + /// @brief data is available + /// @return If data is available, return true. Otherwise return false. + bool is_available() { return state_ && debug_; } + + private: + void create_eval_env(bool global = true, bool upvalue = true, bool local = true) { + lua_createtable(state_, 0, 0); + int envtable = lua_gettop(state_); + lua_createtable(state_, 0, 0); // create metatable for env + int metatable = lua_gettop(state_); + // use global + if (global) { + lua_pushglobaltable(state_); + lua_setfield(state_, metatable, "__index"); + } + + // use upvalue + if (upvalue) { + lua_getinfo(state_, "f", debug_); // push current running function + +#if LUA_VERSION_NUM < 502 + lua_getfenv(state_, -1); + lua_setfield(state_, metatable, "__index"); +#endif + int upvno = 1; + while (const char* varname = lua_getupvalue(state_, -1, upvno++)) { + if (strcmp(varname, "_ENV") == 0) // override _ENV + { + lua_pushvalue(state_, -1); + lua_setfield(state_, metatable, "__index"); + } + lua_setfield(state_, envtable, varname); + } + lua_pop(state_, 1); // pop current running function + } + // use local vars + if (local) { + int varno = 0; + while (const char* varname = lua_getlocal(state_, debug_, ++varno)) { + if (strcmp(varname, "_ENV") == 0) // override _ENV + { + lua_pushvalue(state_, -1); + lua_setfield(state_, metatable, "__index"); + } + lua_setfield(state_, envtable, varname); + } +#if LUA_VERSION_NUM >= 502 + // va arg + if (is_variadic_arg()) { + varno = 0; + lua_createtable(state_, 0, 0); + while (const char* varname = lua_getlocal(state_, debug_, --varno)) { + (void)varname; // unused + lua_rawseti(state_, -2, -varno); + } + if (varno < -1) { + lua_setfield(state_, envtable, "(*vararg)"); + } else { + lua_pop(state_, 1); + } + } +#endif + } + lua_setmetatable(state_, envtable); +#if LUA_VERSION_NUM < 502 + lua_pushvalue(state_, envtable); + lua_setfield(state_, envtable, "_ENV"); +#endif + return; + } + + friend class debugger; + friend class stack_info; + + lua_State* state_; + lua_Debug* debug_; + std::string got_debug_; +}; + +/// @brief stack frame infomation data +class stack_info : private debug_info { + public: + stack_info(lua_State* L, int level) { + memset(&debug_var_, 0, sizeof(debug_var_)); + valid_ = lua_getstack(L, level, &debug_var_) != 0; + if (valid_) { + assign(L, &debug_var_); + } + } + stack_info(const stack_info& other) + : debug_info(other), debug_var_(other.debug_var_), valid_(other.valid_) { + debug_ = &debug_var_; + } + stack_info& operator=(const stack_info& other) { + debug_info::operator=(other); + debug_var_ = other.debug_var_; + valid_ = other.valid_; + debug_ = &debug_var_; + return *this; + } + bool is_available() { return valid_ && debug_info::is_available(); } + ~stack_info() { debug_ = 0; } + using debug_info::assign; + using debug_info::currentline; + using debug_info::get_info; + using debug_info::is_available_info; + using debug_info::lastlinedefined; + using debug_info::linedefined; + using debug_info::name; + using debug_info::namewhat; + using debug_info::number_of_upvalues; + using debug_info::source; + using debug_info::what; +#if LUA_VERSION_NUM >= 502 + using debug_info::is_tailcall; + using debug_info::is_variadic_arg; + using debug_info::number_of_parameters; +#endif + using debug_info::eval; + using debug_info::get_local_vars; + using debug_info::get_upvalues; + using debug_info::set_local_var; + using debug_info::set_upvalue; + using debug_info::short_src; + + private: + lua_Debug debug_var_; + bool valid_; +}; + +/// @brief Debugging interface class +class debugger { + public: + typedef std::vector line_breakpoint_type; + typedef std::function pause_handler_type; + typedef std::function tick_handler_type; + + debugger() : state_(0), pause_(true), step_type_(STEP_ENTRY) {} + debugger(lua_State* L) : state_(0), pause_(true), step_type_(STEP_ENTRY) { reset(L); } + ~debugger() { reset(); } + + /// @brief add breakpoints + /// @param file filename + /// @param line line number + /// @param condition + /// @param hit_condition start <, <=, ==, >, >=, % , followed by value. If + /// operator is omit, equal to >= + /// e.g. + /// ">5" break always after 5 hits + /// "<5" break on the first 4 hits only + void add_breakpoint(const std::string& file, int line, const std::string& condition = "", + const std::string& hit_condition = "") { + breakpoint_info info; + info.file = file; + info.line = line; + info.condition = condition; + if (!hit_condition.empty()) { + if (is_first_cond_operators(hit_condition)) { + info.hit_condition = hit_condition; + } else { + info.hit_condition = ">=" + hit_condition; + } + } + + line_breakpoints_.push_back(info); + } + /// @brief clear breakpoints with filename and line number + /// @param file source filename + /// @param line If minus,ignore line number. default -1 + void clear_breakpoints(const std::string& file, int line = -1) { + line_breakpoints_.erase(std::remove_if(line_breakpoints_.begin(), line_breakpoints_.end(), + [&](const breakpoint_info& b) { + return (line < 0 || b.line == line) && + (b.file == file); + }), + line_breakpoints_.end()); + } + /// @brief clear breakpoints + void clear_breakpoints() { line_breakpoints_.clear(); } + + /// @brief get line break points. + /// @return array of breakpoints. + const line_breakpoint_type& line_breakpoints() const { return line_breakpoints_; } + + // void error_break(bool enable) { error_break_ = enable; } + + /// @brief set tick handler. callback at new line,function call and function + /// return. + void set_tick_handler(tick_handler_type handler) { tick_handler_ = handler; } + + /// @brief set pause handler. callback at paused by pause,step,breakpoint. + /// If want continue pause,execute the loop so as not to return. + /// e.g. basic_server::init + void set_pause_handler(pause_handler_type handler) { pause_handler_ = handler; } + + /// @brief get current debug info,i.e. executing stack frame top. + debug_info& current_debug_info() { return current_debug_info_; } + + /// @brief get breakpoint + breakpoint_info* current_breakpoint() { return current_breakpoint_; } + + /// @brief assign or unassign debug target + void reset(lua_State* L = 0) { + if (state_ != L) { + if (state_) { + unsethook(); + } + state_ = L; + if (state_) { + sethook(); + } + } + } + /// @brief pause + void pause() { step_type_ = STEP_PAUSE; } + /// @brief unpause(continue) + void unpause() { + pause_ = false; + step_type_ = STEP_NONE; + } + /// @brief paused + /// @return If paused, return true. Otherwise return false. + bool paused() { return pause_; } + + /// @brief paused + /// @return string for pause + /// reason."breakpoint","step","step_in","step_out","exception" + const char* pause_reason() { + if (current_breakpoint_) { + return "breakpoint"; + } else if (step_type_ == STEP_OVER) { + return "step"; + } else if (step_type_ == STEP_IN) { + return "step_in"; + } else if (step_type_ == STEP_OUT) { + return "step_out"; + } else if (step_type_ == STEP_PAUSE) { + return "pause"; + } else if (step_type_ == STEP_ENTRY) { + return "entry"; + } + + return "exception"; + } + + /// @brief step. same for step_over + void step() { step_over(); } + + /// @brief step_over + void step_over() { + step_type_ = STEP_OVER; + step_callstack_size_ = get_call_stack().size(); + pause_ = false; + } + /// @brief step in + void step_in() { + step_type_ = STEP_IN; + step_callstack_size_ = get_call_stack().size(); + pause_ = false; + } + /// @brief step out + void step_out() { + step_type_ = STEP_OUT; + step_callstack_size_ = get_call_stack().size(); + pause_ = false; + } + /// @brief get call stack info + /// @return array of call stack information + std::vector get_call_stack() { + std::vector ret; + if (!current_debug_info_.state_) { + return ret; + } + ret.push_back(stack_info(current_debug_info_.state_, 0)); + while (ret.back().is_available()) { + ret.push_back(stack_info(current_debug_info_.state_, ret.size())); + } + ret.pop_back(); + return ret; + } + + /// @brief get global table + /// @param object_depth depth of extract for return value + /// @return global table value + json::value get_global_table(int object_depth = 1) { + lua_pushglobaltable(state_); + json::value v = utility::to_json(state_, -1, object_depth); + lua_pop(state_, 1); // pop global table + return v; + } + + private: + void sethook() { + lua_pushlightuserdata(state_, this_data_key()); + lua_pushlightuserdata(state_, this); + lua_rawset(state_, LUA_REGISTRYINDEX); + lua_sethook(state_, &hook_function, LUA_MASKCALL | LUA_MASKRET | LUA_MASKLINE, 0); + } + void unsethook() { + if (state_) { + lua_sethook(state_, 0, 0, 0); + lua_pushlightuserdata(state_, this_data_key()); + lua_pushnil(state_); + lua_rawset(state_, LUA_REGISTRYINDEX); + state_ = 0; + } + } + debugger(const debugger&); //=delete; + debugger& operator=(const debugger&); //=delete; + + static bool is_path_separator(char c) { return c == '\\' || c == '/'; } + static bool is_file_path_match(const char* path1, const char* path2) { + auto p1 = std::filesystem::canonical(std::filesystem::path(std::string(path1))); + auto p2 = std::filesystem::canonical(std::filesystem::path(std::string(path2))); + bool match = p1 == p2; + return match; + } + + breakpoint_info* search_breakpoints(debug_info& debuginfo) { + if (line_breakpoints_.empty()) { + return 0; + } + int currentline = debuginfo.currentline(); + for (line_breakpoint_type::iterator it = line_breakpoints_.begin(); + it != line_breakpoints_.end(); + ++it) { + if (currentline == it->line) { + const char* source = debuginfo.source(); + if (!source) { + continue; + } + // remove front @ + if (source[0] == '@') { + source++; + } + if (is_file_path_match(it->file.c_str(), source)) { + return &(*it); + } + } + } + return 0; + } + bool breakpoint_cond(const breakpoint_info& breakpoint, debug_info& debuginfo) { + if (!breakpoint.condition.empty()) { + json::array condret = debuginfo.eval(breakpoint.condition.c_str()); + return !condret.empty() && condret[0].evaluate_as_boolean(); + } + return true; + } + static bool is_first_cond_operators(const std::string& cond) { + const char* ops[] = {"<", "==", ">", "%"}; //,"<=" ,">=" + for (size_t i = 0; i < sizeof(ops) / sizeof(ops[0]); ++i) { + if (cond.compare(0, strlen(ops[i]), ops[i]) == 0) { + return true; + } + } + return false; + } + + bool breakpoint_hit_cond(const breakpoint_info& breakpoint, debug_info& debuginfo) { + if (!breakpoint.hit_condition.empty()) { + json::array condret = + debuginfo.eval((std::to_string(breakpoint.hit_count) + breakpoint.hit_condition).c_str()); + + return condret.empty() || condret[0].evaluate_as_boolean(); + } + return true; + } + void hookline() { + current_breakpoint_ = search_breakpoints(current_debug_info_); + if (current_breakpoint_ && breakpoint_cond(*current_breakpoint_, current_debug_info_)) { + current_breakpoint_->hit_count++; + if (breakpoint_hit_cond(*current_breakpoint_, current_debug_info_)) { + pause_ = true; + } + } + } + void hookcall() {} + void hookret() {} + + void tick() { + if (tick_handler_) { + tick_handler_(*this); + } + } + void check_code_step_pause() { + if (step_type_ == STEP_NONE) { + return; + } + + std::vector callstack = get_call_stack(); + switch (step_type_) { + case STEP_OVER: + if (step_callstack_size_ >= callstack.size()) { + pause_ = true; + } + break; + case STEP_IN: + pause_ = true; + break; + case STEP_OUT: + if (step_callstack_size_ > callstack.size()) { + pause_ = true; + } + break; + case STEP_PAUSE: + pause_ = true; + break; + case STEP_ENTRY: + case STEP_NONE: + break; + } + } + void hook(lua_State* L, lua_Debug* ar) { + current_debug_info_.assign(L, ar); + current_breakpoint_ = 0; + tick(); + + if (!pause_ && ar->event == LUA_HOOKLINE) { + check_code_step_pause(); + } + + if (ar->event == LUA_HOOKLINE) { + hookline(); + } else if (ar->event == LUA_HOOKCALL) { + hookcall(); + } else if (ar->event == LUA_HOOKRET) { + hookret(); + } + if (pause_ && pause_handler_) { + step_callstack_size_ = 0; + pause_handler_(*this); + if (step_type_ == STEP_NONE) { + pause_ = false; + } + } + } + static void* this_data_key() { + static int key_data = 0; + return &key_data; + } + static void hook_function(lua_State* L, lua_Debug* ar) { + lua_rawgetp(L, LUA_REGISTRYINDEX, this_data_key()); + + debugger* self = static_cast(lua_touserdata(L, -1)); + lua_pop(L, 1); + self->hook(L, ar); + } + + enum step_type { + STEP_NONE, + STEP_OVER, + STEP_IN, + STEP_OUT, + STEP_PAUSE, + STEP_ENTRY, + }; + + lua_State* state_; + bool pause_; + // bool error_break_; + step_type step_type_; + size_t step_callstack_size_; + debug_info current_debug_info_; + line_breakpoint_type line_breakpoints_; + breakpoint_info* current_breakpoint_; + pause_handler_type pause_handler_; + tick_handler_type tick_handler_; +}; +} // namespace lrdb + +#else +#error Needs at least a C++11 compiler +#endif diff --git a/engine/vendor/lrdb/include/lrdb/message.hpp b/engine/vendor/lrdb/include/lrdb/message.hpp new file mode 100644 index 000000000..c377f674b --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/message.hpp @@ -0,0 +1,249 @@ +#pragma once + +#include "picojson.h" +#include "lrdb/optional.hpp" + +namespace lrdb { +namespace json { +using namespace ::picojson; +} + +namespace message { + +struct request_message { + request_message() {} + request_message(std::string id, std::string method, + json::value params = json::value()) + : id(std::move(id)), + method(std::move(method)), + params(std::move(params)) {} + request_message(int id, std::string method, + json::value params = json::value()) + : id(double(id)), method(std::move(method)), params(std::move(params)) {} + json::value id; + std::string method; + json::value params; +}; + +struct response_error { + int code; + std::string message; + json::value data; + + response_error(int code, std::string message) + : code(code), message(std::move(message)) {} + + enum error_code { + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + serverErrorStart = -32099, + serverErrorEnd = -32000, + ServerNotInitialized, + UnknownErrorCode = -32001 + }; +}; + +struct response_message { + response_message() {} + response_message(std::string id, json::value result = json::value()) + : id(std::move(id)), result(std::move(result)) {} + response_message(int id, json::value result = json::value()) + : id(double(id)), result(std::move(result)) {} + json::value id; + json::value result; + optional error; +}; +struct notify_message { + notify_message(std::string method, json::value params = json::value()) + : method(std::move(method)), params(std::move(params)) {} + std::string method; + json::value params; +}; + +inline bool is_notify(const json::value& msg) { + return msg.is() && !msg.contains("id"); +} +inline bool is_request(const json::value& msg) { + return msg.is() && msg.contains("method") && + msg.get("method").is(); +} +inline bool is_response(const json::value& msg) { + return msg.is() && !msg.contains("method") && + msg.contains("id"); +} + +inline bool parse(const json::value& message, request_message& request) { + if (!is_request(message)) { + return false; + } + request.id = message.get("id"); + request.method = message.get("method").get(); + if (message.contains("param")) { + request.params = message.get("param"); + } else { + request.params = message.get("params"); + } + return true; +} + +inline bool parse(const json::value& message, notify_message& notify) { + if (!is_notify(message)) { + return false; + } + notify.method = message.get("method").get(); + if (message.contains("param")) { + notify.params = message.get("param"); + } else { + notify.params = message.get("params"); + } + return true; +} + +inline bool parse(const json::value& message, response_message& response) { + if (!is_response(message)) { + return false; + } + response.id = message.get("id"); + response.result = message.get("result"); + return true; +} + +inline std::string serialize(const request_message& msg) { + json::object obj; + obj["jsonrpc"] = json::value("2.0"); + + obj["method"] = json::value(msg.method); + if (!msg.params.is()) { + obj["params"] = msg.params; + } + obj["id"] = msg.id; + return json::value(obj).serialize(); +} + +inline std::string serialize(const response_message& msg) { + json::object obj; + obj["jsonrpc"] = json::value("2.0"); + + obj["result"] = msg.result; + + if (msg.error) { + json::object error = {{"code", json::value(double(msg.error->code))}, + {"message", json::value(msg.error->message)}, + {"data", json::value(msg.error->data)}}; + obj["error"] = json::value(error); + } + + obj["id"] = msg.id; + return json::value(obj).serialize(); +} + +inline std::string serialize(const notify_message& msg) { + json::object obj; + obj["jsonrpc"] = json::value("2.0"); + + obj["method"] = json::value(msg.method); + if (!msg.params.is()) { + obj["params"] = msg.params; + } + return json::value(obj).serialize(); +} + +inline const std::string& get_method(const json::value& msg) { + static std::string null; + if (!msg.is() || !msg.contains("method")) { + return null; + } + const json::value& m = msg.get().at("method"); + if (!m.is()) { + return null; + } + return m.get(); +} +inline const json::value& get_param(const json::value& msg) { + static json::value null; + if (!msg.is() || !msg.contains("params")) { + return null; + } + return msg.get().at("params"); +} +inline const json::value& get_id(const json::value& msg) { + static json::value null; + if (!msg.is() || !msg.contains("id")) { + return null; + } + return msg.get().at("id"); +} +namespace request { +inline std::string serialize(const json::value& id, const std::string& medhod, + const json::value& param = json::value()) { + json::object obj; + obj["method"] = json::value(medhod); + if (!param.is()) { + obj["params"] = param; + } + obj["id"] = id; + return json::value(obj).serialize(); +} +inline std::string serialize(const json::value& id, const std::string& medhod, + const std::string& param) { + return serialize(id, medhod, json::value(param)); +} +inline std::string serialize(double id, const std::string& medhod, + const std::string& param) { + return serialize(json::value(id), medhod, json::value(param)); +} +inline std::string serialize(double id, const std::string& medhod, + const json::value& param = json::value()) { + return serialize(json::value(id), medhod, json::value(param)); +} +inline std::string serialize(const std::string& id, const std::string& medhod, + const std::string& param) { + return serialize(json::value(id), medhod, json::value(param)); +} +inline std::string serialize(const std::string& id, const std::string& medhod, + const json::value& param = json::value()) { + return serialize(json::value(id), medhod, json::value(param)); +} +} +namespace notify { +inline std::string serialize(const std::string& medhod, + const json::value& param = json::value()) { + json::object obj; + obj["method"] = json::value(medhod); + if (!param.is()) { + obj["params"] = param; + } + return json::value(obj).serialize(); +} +inline std::string serialize(const std::string& medhod, + const std::string& param) { + return serialize(medhod, json::value(param)); +} +} +namespace responce { +inline std::string serialize(const json::value& id, + const json::value& result = json::value(), + bool error = false) { + json::object obj; + if (error) { + obj["error"] = result; + } else { + obj["result"] = result; + } + obj["id"] = id; + return json::value(obj).serialize(); +} +inline std::string serialize(const json::value& id, const std::string& result, + bool error = false) { + return serialize(id, json::value(result), error); +} +} +} +using message::request_message; +using message::response_message; +using message::notify_message; +using message::response_error; +} diff --git a/engine/vendor/lrdb/include/lrdb/optional.hpp b/engine/vendor/lrdb/include/lrdb/optional.hpp new file mode 100644 index 000000000..2fd4b4421 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/optional.hpp @@ -0,0 +1,261 @@ +// Copyright satoren +// Distributed under the Boost Software License, Version 1.0. (See +// accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#pragma once +#include + +namespace lrdb +{ + /// @addtogroup optional + /// @{ + + struct bad_optional_access :std::exception {}; + struct nullopt_t {}; + + /// @brief self implement for std::optional(C++17 feature). + templateclass optional + { + typedef void (optional::*bool_type)() const; + void this_type_does_not_support_comparisons() const {} + public: + optional() : value_(0) {}; + optional(nullopt_t) : value_(0) {}; + optional(const optional& other) : value_(0) + { + if (other) + { + value_ = new(&storage_) T(other.value()); + } + } + optional(const T& value) + { + value_ = new(&storage_) T(value); + } + + ~optional() { + destruct(); + } + optional& operator=(nullopt_t) { destruct(); return *this; } + optional& operator=(const optional& other) + { + if (other) + { + *this = other.value(); + } + else + { + destruct(); + } + return *this; + } + optional& operator=(const T& value) + { + if (value_) + { + *value_ = value; + } + else + { + value_ = new(&storage_) T(value); + } + return *this; + } + +#if KAGUYA_USE_CPP11 + optional(optional&& other) :value_(0) + { + if (other) + { + value_ = new(&storage_) T(std::move(other.value())); + } + } + optional(T&& value) + { + value_ = new(&storage_) T(std::move(value)); + } + optional& operator=(optional&& other) + { + if (other) + { + *this = std::move(other.value()); + } + else + { + destruct(); + } + return *this; + } + optional& operator=(T&& value) + { + if (value_) + { + *value_ = std::move(value); + } + else + { + value_ = new(&storage_) T(std::move(value)); + } + return *this; + } +#endif + + operator bool_type() const + { + this_type_does_not_support_comparisons(); + return value_ != 0 ? &optional::this_type_does_not_support_comparisons : 0; + } + T& value() + { + if (value_) { return *value_; } + throw bad_optional_access(); + } + const T & value() const + { + if (value_) { return *value_; } + throw bad_optional_access(); + } + +#if KAGUYA_USE_CPP11 + template< class U > + T value_or(U&& default_value) const + { + if (value_) { return *value_; } + return default_value; + } +#else + template< class U > + T value_or(const U& default_value)const + { + if (value_) { return *value_; } + return default_value; + } +#endif + const T* operator->() const { assert(value_); return value_; } + T* operator->() { assert(value_); return value_; } + const T& operator*() const { assert(value_); return *value_; } + T& operator*() { assert(value_); return *value_; } + private: + void destruct() + { + if (value_) + { + value_->~T(); value_ = 0; + } + } + + typename std::aligned_storage::value>::type storage_; + + T* value_; + }; + + /// @brief specialize optional for reference. + /// sizeof(optional) == sizeof(T*) + templateclass optional + { + typedef void (optional::*bool_type)() const; + void this_type_does_not_support_comparisons() const {} + public: + optional() : value_(0) {}; + optional(nullopt_t) : value_(0) {}; + + optional(const optional& other) :value_(other.value_) { } + optional(T& value) :value_(&value) { } + + ~optional() { + } + optional& operator=(nullopt_t) { + value_ = 0; + return *this; + } + optional& operator=(const optional& other) + { + value_ = other.value_; + return *this; + } + optional& operator=(T& value) + { + value_ = &value; + return *this; + } + operator bool_type() const + { + this_type_does_not_support_comparisons(); + return value_ != 0 ? &optional::this_type_does_not_support_comparisons : 0; + } + T& value() + { + if (value_) { return *value_; } + throw bad_optional_access(); + } + const T & value() const + { + if (value_) { return *value_; } + throw bad_optional_access(); + } + +#if KAGUYA_USE_CPP11 + T& value_or(T& default_value) const + { + if (value_) { return *value_; } + return default_value; + } +#else + T& value_or(T& default_value)const + { + if (value_) { return *value_; } + return default_value; + } +#endif + + const T* operator->() const { assert(value_); return value_; } + T* operator->() { assert(value_); return value_; } + const T& operator*() const { assert(value_); return *value_; } + T& operator*() { assert(value_); return *value_; } + private: + T* value_; + }; + + /// @name relational operators + /// @brief + ///@{ + template< class T > + bool operator==(const optional& lhs, const optional& rhs) + { + if (bool(lhs) != bool(rhs)) { return false; } + if (bool(lhs) == false) { return true; } + return lhs.value() == rhs.value(); + } + template< class T > + bool operator!=(const optional& lhs, const optional& rhs) + { + return !(lhs == rhs); + } + template< class T > + bool operator<(const optional& lhs, const optional& rhs) + { + if (!bool(rhs)) { return false; } + if (!bool(lhs)) { return true; } + return lhs.value() < rhs.value(); + } + template< class T > + bool operator<=(const optional& lhs, const optional& rhs) + { + return !(rhs < lhs); + } + template< class T > + bool operator>(const optional& lhs, const optional& rhs) + { + return rhs < lhs; + } + template< class T > + bool operator>=(const optional& lhs, const optional& rhs) + { + return !(lhs < rhs); + } + /// @} + + /// @} +} \ No newline at end of file diff --git a/engine/vendor/lrdb/include/lrdb/server.hpp b/engine/vendor/lrdb/include/lrdb/server.hpp new file mode 100644 index 000000000..fdda423b9 --- /dev/null +++ b/engine/vendor/lrdb/include/lrdb/server.hpp @@ -0,0 +1,16 @@ +#pragma once + +#if __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800) +#include +#include +#include + +#include "basic_server.hpp" +#include "command_stream/socket.hpp" +namespace lrdb { +typedef basic_server server; +} + +#else +#error Needs at least a C++11 compiler +#endif \ No newline at end of file diff --git a/engine/vendor/lrdb/third_party/picojson/LICENSE b/engine/vendor/lrdb/third_party/picojson/LICENSE new file mode 100644 index 000000000..72f355391 --- /dev/null +++ b/engine/vendor/lrdb/third_party/picojson/LICENSE @@ -0,0 +1,25 @@ +Copyright 2009-2010 Cybozu Labs, Inc. +Copyright 2011-2014 Kazuho Oku +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/engine/vendor/lrdb/third_party/picojson/picojson.h b/engine/vendor/lrdb/third_party/picojson/picojson.h new file mode 100644 index 000000000..14c580b68 --- /dev/null +++ b/engine/vendor/lrdb/third_party/picojson/picojson.h @@ -0,0 +1,1105 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// for isnan/isinf +#if __cplusplus>=201103L +# include +#else +extern "C" { +# ifdef _MSC_VER +# include +# elif defined(__INTEL_COMPILER) +# include +# else +# include +# endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +# if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600) +# define PICOJSON_USE_RVALUE_REFERENCE 1 +# else +# define PICOJSON_USE_RVALUE_REFERENCE 0 +# endif +#endif//PICOJSON_USE_RVALUE_REFERENCE + + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +# define __STDC_FORMAT_MACROS +# include +# include +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +# define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +# include +} +#endif + +#ifndef PICOJSON_ASSERT +# define PICOJSON_ASSERT(e) do { if (! (e)) throw std::runtime_error(#e); } while (0) +#endif + +#ifdef _MSC_VER + #define SNPRINTF _snprintf_s + #pragma warning(push) + #pragma warning(disable : 4244) // conversion from int to char + #pragma warning(disable : 4127) // conditional expression is constant + #pragma warning(disable : 4702) // unreachable code +#else + #define SNPRINTF snprintf +#endif + +namespace picojson { + + enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , int64_type +#endif + }; + + enum { + INDENT_WIDTH = 2 + }; + + struct null {}; + + class value { + public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string* string_; + array* array_; + object* object_; + }; + protected: + int type_; + _storage u_; + public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string& s); + explicit value(const array& a); + explicit value(const object& o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string&& s); + explicit value(array&& a); + explicit value(object&& o); +#endif + explicit value(const char* s); + value(const char* s, size_t len); + ~value(); + value(const value& x); + value& operator=(const value& x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value&& x)throw(); + value& operator=(value&& x)throw(); +#endif + void swap(value& x)throw(); + template bool is() const; + template const T& get() const; + template T& get(); + template void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value& get(size_t idx) const; + const value& get(const std::string& key) const; + value& get(size_t idx); + value& get(const std::string& key); + + bool contains(size_t idx) const; + bool contains(const std::string& key) const; + std::string to_str() const; + template void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + private: + template value(const T*); // intentionally defined to block implicit conversion of pointer to bool + template static void _indent(Iter os, int indent); + template void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); + }; + + typedef value::array array; + typedef value::object object; + + inline value::value() : type_(null_type) {} + + inline value::value(int type, bool) : type_(type) { + switch (type) { +#define INIT(p, v) case p##type: u_.p = v; break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: break; + } + } + + inline value::value(bool b) : type_(boolean_type) { + u_.boolean_ = b; + } + +#ifdef PICOJSON_USE_INT64 + inline value::value(int64_t i) : type_(int64_type) { + u_.int64_ = i; + } +#endif + + inline value::value(double n) : type_(number_type) { + if ( +#ifdef _MSC_VER + ! _finite(n) +#elif __cplusplus>=201103L || !(defined(isnan) && defined(isinf)) + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; + } + + inline value::value(const std::string& s) : type_(string_type) { + u_.string_ = new std::string(s); + } + + inline value::value(const array& a) : type_(array_type) { + u_.array_ = new array(a); + } + + inline value::value(const object& o) : type_(object_type) { + u_.object_ = new object(o); + } + +#if PICOJSON_USE_RVALUE_REFERENCE + inline value::value(std::string&& s) : type_(string_type) { + u_.string_ = new std::string(std::move(s)); + } + + inline value::value(array&& a) : type_(array_type) { + u_.array_ = new array(std::move(a)); + } + + inline value::value(object&& o) : type_(object_type) { + u_.object_ = new object(std::move(o)); + } +#endif + + inline value::value(const char* s) : type_(string_type) { + u_.string_ = new std::string(s); + } + + inline value::value(const char* s, size_t len) : type_(string_type) { + u_.string_ = new std::string(s, len); + } + + inline void value::clear() { + switch (type_) { +#define DEINIT(p) case p##type: delete u_.p; break + DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: break; + } + } + + inline value::~value() { + clear(); + } + + inline value::value(const value& x) : type_(x.type_) { + switch (type_) { +#define INIT(p, v) case p##type: u_.p = v; break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } + } + + inline value& value::operator=(const value& x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; + } + +#if PICOJSON_USE_RVALUE_REFERENCE + inline value::value(value&& x)throw() : type_(null_type) { + swap(x); + } + inline value& value::operator=(value&& x)throw() { + swap(x); + return *this; + } +#endif + inline void value::swap(value& x)throw() { + std::swap(type_, x.type_); + std::swap(u_, x.u_); + } + +#define IS(ctype, jtype) \ + template <> inline bool value::is() const { \ + return type_ == jtype##_type; \ + } + IS(null, null) + IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 + IS(int64_t, int64) +#endif + IS(std::string, string) + IS(array, array) + IS(object, object) +#undef IS + template <> inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; + } + +#define GET(ctype, var) \ + template <> inline const ctype& value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" \ + && is()); \ + return var; \ + } \ + template <> inline ctype& value::get() { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" \ + && is()); \ + return var; \ + } + GET(bool, u_.boolean_) + GET(std::string, *u_.string_) + GET(array, *u_.array_) + GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 + GET(double, (type_ == int64_type && (const_cast(this)->type_ = number_type, const_cast(this)->u_.number_ = u_.int64_), u_.number_)) + GET(int64_t, u_.int64_) +#else + GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } + SET(bool, boolean, u_.boolean_ = _val;) + SET(std::string, string, u_.string_ = new std::string(_val);) + SET(array, array, u_.array_ = new array(_val);) + SET(object, object, u_.object_ = new object(_val);) + SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 + SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> inline void value::set(ctype &&_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } + MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) + MOVESET(array, array, u_.array_ = new array(std::move(_val));) + MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + + inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return ! u_.string_->empty(); + default: + return true; + } + } + + inline const value& value::get(size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; + } + + inline value& value::get(size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; + } + + inline const value& value::get(const std::string& key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; + } + + inline value& value::get(const std::string& key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; + } + + inline bool value::contains(size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); + } + + inline bool value::contains(const std::string& key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); + } + + inline std::string value::to_str() const { + switch (type_) { + case null_type: return "null"; + case boolean_type: return u_.boolean_ ? "true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: return *u_.string_; + case array_type: return "array"; + case object_type: return "object"; + default: PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); + } + + template void copy(const std::string& s, Iter oi) { + std::copy(s.begin(), s.end(), oi); + } + + template + struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) case val: copy(sym, oi); break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } + }; + + template void serialize_str(const std::string& s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = { oi }; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; + } + + template void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); + } + + inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 0 : -1); + } + + template void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } + } + + template void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); + i != u_.array_->end(); + ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (! u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); + i != u_.object_->end(); + ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (! u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } + } + + inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; + } + + template class input { + protected: + Iter cur_, end_; + bool consumed_; + int line_; + public: + input(const Iter& first, const Iter& last) : cur_(first), end_(last), consumed_(false), line_(1) {} + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { + consumed_ = false; + } + Iter cur() const { + if (consumed_) { + input *self = const_cast*>(this); + self->consumed_ = false; + ++self->cur_; + } + return cur_; + } + int line() const { return line_; } + void skip_ws() { + while (1) { + int ch = getc(); + if (! (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(int expect) { + skip_ws(); + if (getc() != expect) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string& pattern) { + for (std::string::const_iterator pi(pattern.begin()); + pi != pattern.end(); + ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } + }; + + template inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; + } + + template inline bool _parse_codepoint(String& out, input& in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (! (0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(uni_ch); + } else { + if (uni_ch < 0x800) { + out.push_back(0xc0 | (uni_ch >> 6)); + } else { + if (uni_ch < 0x10000) { + out.push_back(0xe0 | (uni_ch >> 12)); + } else { + out.push_back(0xf0 | (uni_ch >> 18)); + out.push_back(0x80 | ((uni_ch >> 12) & 0x3f)); + } + out.push_back(0x80 | ((uni_ch >> 6) & 0x3f)); + } + out.push_back(0x80 | (uni_ch & 0x3f)); + } + return true; + } + + template inline bool _parse_string(String& out, input& in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) case sym: out.push_back(val); break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (! _parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(ch); + } + } + return false; + } + + template inline bool _parse_array(Context& ctx, input& in) { + if (! ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (! ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); + } + + template inline bool _parse_object(Context& ctx, input& in) { + if (! ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return true; + } + do { + std::string key; + if (! in.expect('"') + || ! _parse_string(key, in) + || ! in.expect(':')) { + return false; + } + if (! ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}'); + } + + template inline std::string _parse_number(input& in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' + || ch == 'e' || ch == 'E') { + num_str.push_back(ch); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; + } + + template inline bool _parse(Context& ctx, input& in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str = _parse_number(in); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 + && std::numeric_limits::min() <= ival + && ival <= std::numeric_limits::max() + && endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; + } + + class deny_parse_context { + public: + bool set_null() { return false; } + bool set_bool(bool) { return false; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return false; } +#endif + bool set_number(double) { return false; } + template bool parse_string(input&) { return false; } + bool parse_array_start() { return false; } + template bool parse_array_item(input&, size_t) { + return false; + } + bool parse_array_stop(size_t) { return false; } + bool parse_object_start() { return false; } + template bool parse_object_item(input&, const std::string&) { + return false; + } + }; + + class default_parse_context { + protected: + value* out_; + public: + default_parse_context(value* out) : out_(out) {} + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template bool parse_string(input& in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + *out_ = value(array_type, false); + return true; + } + template bool parse_array_item(input& in, size_t) { + array& a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back()); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { return true; } + bool parse_object_start() { + *out_ = value(object_type, false); + return true; + } + template bool parse_object_item(input& in, const std::string& key) { + object& o = out_->get(); + default_parse_context ctx(&o[key]); + return _parse(ctx, in); + } + private: + default_parse_context(const default_parse_context&); + default_parse_context& operator=(const default_parse_context&); + }; + + class null_parse_context { + public: + struct dummy_str { + void push_back(int) {} + }; + public: + null_parse_context() {} + bool set_null() { return true; } + bool set_bool(bool) { return true; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return true; } +#endif + bool set_number(double) { return true; } + template bool parse_string(input& in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { return true; } + template bool parse_array_item(input& in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { return true; } + bool parse_object_start() { return true; } + template bool parse_object_item(input& in, const std::string&) { + return _parse(*this, in); + } + private: + null_parse_context(const null_parse_context&); + null_parse_context& operator=(const null_parse_context&); + }; + + // obsolete, use the version below + template inline std::string parse(value& out, Iter& pos, const Iter& last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; + } + + template inline Iter _parse(Context& ctx, const Iter& first, const Iter& last, std::string* err) { + input in(first, last); + if (! _parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(ch); + } + } + } + return in.cur(); + } + + template inline Iter parse(value& out, const Iter& first, const Iter& last, std::string* err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); + } + + inline std::string parse(value& out, const std::string& s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; + } + + inline std::string parse(value& out, std::istream& is) { + std::string err; + parse(out, std::istreambuf_iterator(is.rdbuf()), + std::istreambuf_iterator(), &err); + return err; + } + + template struct last_error_t { + static std::string s; + }; + template std::string last_error_t::s; + + inline void set_last_error(const std::string& s) { + last_error_t::s = s; + } + + inline const std::string& get_last_error() { + return last_error_t::s; + } + + inline bool operator==(const value& x, const value& y) { + if (x.is()) + return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) \ + return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; + } + + inline bool operator!=(const value& x, const value& y) { + return ! (x == y); + } +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { + template<> inline void swap(picojson::value& x, picojson::value& y) + { + x.swap(y); + } +} +#endif + +inline std::istream& operator>>(std::istream& is, picojson::value& x) +{ + picojson::set_last_error(std::string()); + std::string err = picojson::parse(x, is); + if (! err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) +{ + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +#endif diff --git a/tests/conanfile_deployment.py b/tests/conanfile_deployment.py index 22e91d3ae..a16f63e49 100644 --- a/tests/conanfile_deployment.py +++ b/tests/conanfile_deployment.py @@ -29,6 +29,7 @@ class CloeStandardDeployment(ConanFile): "fable:allow_comments": True, "cloe-engine:server": True, + "cloe-engine:lrdb": True, } @property From a0b85697a7d27ea1fc80f2afe384aa46cd8399f4 Mon Sep 17 00:00:00 2001 From: Benjamin Morgan Date: Thu, 16 May 2024 18:33:28 +0200 Subject: [PATCH 10/22] all: Add DataBroker and Signals concepts TODO: - Add documentation on how to use it. - Add tests using them. Author: Martin Henselmeyer --- engine/lua/cloe-engine/init.lua | 9 + engine/lua/cloe/engine.lua | 152 ++ engine/src/coordinator.cpp | 4 +- engine/src/coordinator.hpp | 5 +- engine/src/registrar.hpp | 12 +- engine/src/simulation.cpp | 273 ++- engine/src/simulation.hpp | 3 + engine/src/simulation_context.hpp | 2 + engine/src/stack.hpp | 4 + engine/src/stack_test.cpp | 2 + models/CMakeLists.txt | 4 + models/conanfile.py | 1 + models/include/cloe/component/object.hpp | 19 + models/include/cloe/component/wheel.hpp | 8 + models/include/cloe/utility/lua_types.hpp | 44 + models/src/cloe/utility/lua_types.cpp | 393 ++++ models/src/cloe/utility/lua_types_test.cpp | 118 ++ plugins/basic/src/basic.cpp | 42 + runtime/CMakeLists.txt | 11 + runtime/include/cloe/cloe_fwd.hpp | 7 + runtime/include/cloe/data_broker.hpp | 2063 ++++++++++++++++++++ runtime/include/cloe/registrar.hpp | 3 + runtime/src/cloe/data_broker.cpp | 36 + runtime/src/cloe/data_broker_test.cpp | 1025 ++++++++++ tests/test_engine_json_schema.json | 24 + tests/test_engine_nop_smoketest_dump.json | 1 + tests/test_lua04_schedule_test.lua | 28 +- 27 files changed, 4277 insertions(+), 16 deletions(-) create mode 100644 models/include/cloe/utility/lua_types.hpp create mode 100644 models/src/cloe/utility/lua_types.cpp create mode 100644 models/src/cloe/utility/lua_types_test.cpp create mode 100644 runtime/include/cloe/data_broker.hpp create mode 100644 runtime/src/cloe/data_broker.cpp create mode 100644 runtime/src/cloe/data_broker_test.cpp diff --git a/engine/lua/cloe-engine/init.lua b/engine/lua/cloe-engine/init.lua index a53ce533d..ce16073b5 100644 --- a/engine/lua/cloe-engine/init.lua +++ b/engine/lua/cloe-engine/init.lua @@ -34,6 +34,12 @@ local engine = { --- @type number Number of triggers processed from the initial input. triggers_processed = 0, + + --- @type table Map of signal names to regular expression matches. + signal_aliases = {}, + + --- @type string[] List of signals to make available during simulation. + signal_requires = {}, }, --- Contains engine state for a simulation. @@ -80,6 +86,9 @@ local engine = { --- @type table Namespaced Lua interfaces of instantiated plugins. plugins = {}, + + --- @type table Table of required signals. + signals = {}, } require("cloe-engine.types") diff --git a/engine/lua/cloe/engine.lua b/engine/lua/cloe/engine.lua index 8a3e86df6..6e1d954ea 100644 --- a/engine/lua/cloe/engine.lua +++ b/engine/lua/cloe/engine.lua @@ -124,6 +124,158 @@ function engine.log(level, fmt, ...) api.log(level, "lua", msg) end +--- Alias a set of signals in the Cloe data broker. +--- +--- @param list table # regular expression to alias key +--- @return table # current signal aliases table +function engine.alias_signals(list) + -- TODO: Throw an error if simulation already started. + api.initial_input.signal_aliases = luax.tbl_extend("force", api.initial_input.signal_aliases, list) + return api.initial_input.signal_aliases +end + +--- Require a set of signals to be made available via the Cloe data broker. +--- +--- @param list string[] signals to merge into main list of required signals +--- @return string[] # merged list of signals +function engine.require_signals(list) + -- TODO: Throw an error if simulation already started. + api.initial_input.signal_requires = luax.tbl_extend("force", api.initial_input.signal_requires, list) + return api.initial_input.signal_requires +end + +--- Optionally alias and require a set of signals from a signals enum list. +--- +--- This allows you to make an enum somewhere which the language server +--- can use for autocompletion and which you can use as an alias: +--- +--- ---@enum Sig +--- local Sig = { +--- DriverDoorLatch = "vehicle::framework::chassis::.*driver_door::latch", +--- VehicleMps = "vehicle::sensors::chassis::velocity", +--- } +--- cloe.require_signals_enum(Sig, true) +--- +--- Later, you can use the enum with `cloe.signal()`: +--- +--- cloe.signal(Sig.DriverDoorLatch) +--- +--- @param enum table input mappging from enum name to signal name +--- @param alias boolean whether to treat signal names as alias regular expressions +--- @return nil +function engine.require_signals_enum(enum, alias) + -- TODO: Throw an error if simulation already started. + local signals = {} + if alias then + local aliases = {} + for key, sigregex in pairs(enum) do + table.insert(aliases, { sigregex, key }) + table.insert(signals, key) + end + engine.alias_signals(aliases) + else + for _, signame in pairs(enum) do + table.insert(signals, signame) + end + end + engine.require_signals(signals) +end + +--- Return full list of loaded signals. +--- +--- Example: +--- +--- local signals = cloe.signals() +--- signals[SigName] = value +--- +--- @return table +function engine.signals() + return api.signals +end + +--- Return the specified signal. +--- +--- If the signal does not exist, nil is returned. +--- +--- If you want to set the signal, you need to use `cloe.set_signal()` +--- or access the value via `cloe.signals()`. +--- +--- @param name string signal name +--- @return any|nil # signal value +function engine.signal(name) + return api.signals[name] +end + +--- Set the specified signal with a value. +--- +--- @param name string signal name +--- @param value any signal value +--- @return nil +function engine.set_signal(name, value) + api.signals[name] = value +end + +--- Record the given list of signals into the report. +--- +--- This can be called multiple times, but if the signal is already +--- being recorded, then an error will be raised. +--- +--- This should be called before simulation starts, +--- so not from a scheduled callback. +--- +--- You can pass it a list of signals to record, or a mapping +--- from name to +--- +--- @param mapping table mapping from signal names +--- @return nil +function engine.record_signals(mapping) + validate("cloe.record_signals(table)", mapping) + api.state.report.signals = api.state.report.signals or {} + local signals = api.state.report.signals + signals.time = signals.time or {} + for sig, getter in pairs(mapping) do + if type(sig) == "number" then + if type(getter) ~= "string" then + error("positional signals can only be signal names") + end + sig = getter + end + if signals[sig] then + error("signal already exists: " .. sig) + end + signals[sig] = {} + end + + engine.schedule({ + on = "loop", + pin = true, + run = function(sync) + local last_time = signals.time[#signals.time] + local cur_time = sync:time():ms() + if last_time ~= cur_time then + table.insert(signals.time, cur_time) + end + + for name, getter in pairs(mapping) do + local value + if type(name) == "number" then + name = getter + end + if type(getter) == "string" then + value = engine.signal(getter) + else + value = getter() + end + if value == nil then + -- TODO: Improve error message! + error("nil value received as signal value") + end + table.insert(signals[name], value) + end + end, + }) +end + --- Schedule a trigger. --- --- It is not recommended to use this low-level function, as it is viable to change. diff --git a/engine/src/coordinator.cpp b/engine/src/coordinator.cpp index 31f91c8a4..3c40dd166 100644 --- a/engine/src/coordinator.cpp +++ b/engine/src/coordinator.cpp @@ -56,8 +56,8 @@ void to_json(Json& j, const HistoryTrigger& t) { j["at"] = t.when; } -Coordinator::Coordinator(sol::state_view lua) - : lua_(lua), executer_registrar_(trigger_registrar(Source::TRIGGER)) {} +Coordinator::Coordinator(sol::state_view lua, cloe::DataBroker* db) + : lua_(lua), executer_registrar_(trigger_registrar(Source::TRIGGER)), db_(db) {} class TriggerRegistrar : public cloe::TriggerRegistrar { public: diff --git a/engine/src/coordinator.hpp b/engine/src/coordinator.hpp index 7bf21605b..bace461f5 100644 --- a/engine/src/coordinator.hpp +++ b/engine/src/coordinator.hpp @@ -98,7 +98,7 @@ struct HistoryTrigger { */ class Coordinator { public: - Coordinator(sol::state_view lua); + Coordinator(sol::state_view lua, cloe::DataBroker* db); const std::vector& history() const { return history_; } @@ -109,6 +109,8 @@ class Coordinator { sol::table register_lua_table(const std::string& field); + cloe::DataBroker* data_broker() const { return db_; } + std::shared_ptr trigger_registrar(cloe::Source s); void enroll(cloe::Registrar& r); @@ -149,6 +151,7 @@ class Coordinator { std::map actions_; std::map events_; sol::state_view lua_; + cloe::DataBroker* db_; // non-owning // Execution: std::shared_ptr executer_registrar_; diff --git a/engine/src/registrar.hpp b/engine/src/registrar.hpp index bba820b1d..cc88edfc8 100644 --- a/engine/src/registrar.hpp +++ b/engine/src/registrar.hpp @@ -33,14 +33,14 @@ namespace engine { class Registrar : public cloe::Registrar { public: - Registrar(std::unique_ptr r, Coordinator* c) - : server_registrar_(std::move(r)), coordinator_(c) {} + Registrar(std::unique_ptr r, Coordinator* c, cloe::DataBroker* db) + : server_registrar_(std::move(r)), coordinator_(c), data_broker_(db) {} Registrar(const Registrar& ar, const std::string& trigger_prefix, const std::string& static_prefix, const std::string& api_prefix) - : coordinator_(ar.coordinator_) { + : coordinator_(ar.coordinator_), data_broker_(ar.data_broker_) { if (trigger_prefix.empty()) { trigger_prefix_ = ar.trigger_prefix_; } else { @@ -107,9 +107,15 @@ class Registrar : public cloe::Registrar { return coordinator_->register_lua_table(trigger_prefix_); } + cloe::DataBroker& data_broker() const override { + assert(data_broker_ != nullptr); + return *data_broker_; + } + private: std::unique_ptr server_registrar_; Coordinator* coordinator_; // non-owning + cloe::DataBroker* data_broker_; // non-owning std::string trigger_prefix_; }; diff --git a/engine/src/simulation.cpp b/engine/src/simulation.cpp index 81ac83082..c82e00501 100644 --- a/engine/src/simulation.cpp +++ b/engine/src/simulation.cpp @@ -86,6 +86,7 @@ #include // for Controller #include // for AsyncAbort +#include // for DataBroker #include // for DirectCallback #include // for Simulator #include // for CommandFactory, BundleFactory, ... @@ -758,12 +759,208 @@ size_t insert_triggers_from_config(SimulationContext& ctx) { return count; } +/** + * Pseudo-class which hosts the Cloe-Signals as properties inside of the Lua-VM + */ +class LuaCloeSignal {}; + StateId SimulationMachine::Start::impl(SimulationContext& ctx) { logger()->info("Starting simulation..."); // Begin execution progress ctx.progress.exec_begin(); + { + // Bind lua state_view to databroker + auto* dbPtr = ctx.coordinator->data_broker(); + if (!dbPtr) { + throw std::logic_error("Coordinator did not provide a DataBroker instance"); + } + auto& db = *dbPtr; + // Alias signals via lua + { + bool aliasing_failure = false; + // Read cloe.alias_signals + sol::object signal_aliases = cloe::luat_cloe_engine_initial_input(ctx.lua)["signal_aliases"]; + auto type = signal_aliases.get_type(); + switch (type) { + // cloe.alias_signals: expected is a list (i.e. table) of 2-tuple each strings + case sol::type::table: { + sol::table alias_signals = signal_aliases.as(); + auto tbl_size = std::distance(alias_signals.begin(), alias_signals.end()); + //for (auto& kv : alias_signals) + for (int i = 0; i < tbl_size; i++) { + //sol::object value = kv.second; + sol::object value = alias_signals[i + 1]; + sol::type type = value.get_type(); + switch (type) { + // cloe.alias_signals[i]: expected is a 2-tuple (i.e. table) each strings + case sol::type::table: { + sol::table alias_tuple = value.as(); + auto tbl_size = std::distance(alias_tuple.begin(), alias_tuple.end()); + if (tbl_size != 2) { + // clang-format off + logger()->error( + "One or more entries in 'cloe.alias_signals' does not consist out of a 2-tuple. " + "Expected are entries in this format { \"regex\" , \"short-name\" }" + ); + // clang-format on + aliasing_failure = true; + continue; + } + + sol::object value; + sol::type type; + std::string old_name; + std::string alias_name; + value = alias_tuple[1]; + type = value.get_type(); + if (sol::type::string != type) { + // clang-format off + logger()->error( + "One or more parts in a tuple in 'cloe.alias_signals' has an unexpected datatype '{}'. " + "Expected are entries in this format { \"regex\" , \"short-name\" }", + static_cast(type)); + // clang-format on + aliasing_failure = true; + } else { + old_name = value.as(); + } + + value = alias_tuple[2]; + type = value.get_type(); + if (sol::type::string != type) { + // clang-format off + logger()->error( + "One or more parts in a tuple in 'cloe.alias_signals' has an unexpected datatype '{}'. " + "Expected are entries in this format { \"regex\" , \"short-name\" }", + static_cast(type)); + // clang-format on + aliasing_failure = true; + } else { + alias_name = value.as(); + } + try { + db.alias(old_name, alias_name); + // clang-format off + logger()->info( + "Aliasing signal '{}' as '{}'.", + old_name, alias_name); + // clang-format on + } catch (const std::logic_error& ex) { + // clang-format off + logger()->error( + "Aliasing signal specifier '{}' as '{}' failed with this error: {}", + old_name, alias_name, ex.what()); + // clang-format on + aliasing_failure = true; + } catch (...) { + // clang-format off + logger()->error( + "Aliasing signal specifier '{}' as '{}' failed.", + old_name, alias_name); + // clang-format on + aliasing_failure = true; + } + } break; + // cloe.alias_signals[i]: is not a table + default: { + // clang-format off + logger()->error( + "One or more entries in 'cloe.alias_signals' has an unexpected datatype '{}'. " + "Expected are entries in this format { \"regex\" , \"short-name\" }", + static_cast(type)); + // clang-format on + aliasing_failure = true; + } break; + } + } + + } break; + case sol::type::none: + case sol::type::lua_nil: { + // not defined -> nop + } break; + default: { + // clang-format off + logger()->error( + "Expected symbol 'cloe.alias_signals' has unexpected datatype '{}'. " + "Expected is a list of 2-tuples in this format { \"regex\" , \"short-name\" }", + static_cast(type)); + // clang-format on + aliasing_failure = true; + } break; + } + if (aliasing_failure) { + throw cloe::ModelError("Aliasing signals failed with above error. Aborting."); + } + } + + // Inject requested signals into lua + { + auto& signals = db.signals(); + bool binding_failure = false; + // Read cloe.require_signals + sol::object value = cloe::luat_cloe_engine_initial_input(ctx.lua)["signal_requires"]; + auto type = value.get_type(); + switch (type) { + // cloe.require_signals expected is a list (i.e. table) of strings + case sol::type::table: { + sol::table require_signals = value.as(); + auto tbl_size = std::distance(require_signals.begin(), require_signals.end()); + + for (int i = 0; i < tbl_size; i++) { + sol::object value = require_signals[i + 1]; + + sol::type type = value.get_type(); + if (type != sol::type::string) { + logger()->warn( + "One entry of cloe.require_signals has a wrong data type: '{}'. " + "Expected is a list of strings.", + static_cast(type)); + binding_failure = true; + continue; + } + std::string signal_name = value.as(); + + // virtually bind signal 'signal_name' to lua + auto iter = db[signal_name]; + if (iter != signals.end()) { + try { + db.bind_signal(signal_name); + logger()->info("Binding signal '{}' as '{}'.", signal_name, signal_name); + } catch (const std::logic_error& ex) { + logger()->error("Binding signal '{}' failed with error: {}", signal_name, + ex.what()); + } + } else { + logger()->warn("Requested signal '{}' does not exist in DataBroker.", signal_name); + binding_failure = true; + } + } + // actually bind all virtually bound signals to lua + db.bind("signals", cloe::luat_cloe_engine(ctx.lua)); + } break; + case sol::type::none: + case sol::type::lua_nil: { + logger()->warn( + "Expected symbol 'cloe.require_signals' appears to be undefined. " + "Expected is a list of string."); + } break; + default: { + logger()->error( + "Expected symbol 'cloe.require_signals' has unexpected datatype '{}'. " + "Expected is a list of string.", + static_cast(type)); + binding_failure = true; + } break; + } + if (binding_failure) { + throw cloe::ModelError("Binding signals to Lua failed with above error. Aborting."); + } + } + } + // Process initial trigger list insert_triggers_from_config(ctx); ctx.coordinator->process_pending_lua_triggers(ctx.sync); @@ -1211,12 +1408,75 @@ Simulation::Simulation(cloe::Stack&& config, sol::state&& lua, const std::string , logger_(cloe::logger::get("cloe")) , uuid_(uuid) {} +struct SignalReport { + std::string name; + std::vector names; + + friend void to_json(cloe::Json& j, const SignalReport& r) { + j = cloe::Json{ + {"name", r.name}, + {"names", r.names}, + }; + } +}; +struct SignalsReport { + std::vector signals; + + friend void to_json(cloe::Json& j, const SignalsReport& r) { + j = cloe::Json{ + {"signals", r.signals}, + }; + } +}; + +cloe::Json dump_signals(cloe::DataBroker& db) { + SignalsReport report; + + const auto& signals = db.signals(); + for (const auto& [key, signal] : signals) { + // create signal + auto& signalreport = report.signals.emplace_back(); + // copy the signal-names + signalreport.name = key; + std::copy(signal->names().begin(), signal->names().end(), + std::back_inserter(signalreport.names)); + + const auto& metadata = signal->metadatas(); + } + + auto json = cloe::Json{report}; + return json; +} + +std::vector dump_signals_autocompletion(cloe::DataBroker& db) { + auto result = std::vector{}; + result.emplace_back("--- @meta"); + result.emplace_back("--- @class signals"); + + const auto& signals = db.signals(); + for (const auto& [key, signal] : signals) { + const auto* tag = signal->metadata(); + if (tag) { + const auto lua_type = to_string(tag->datatype); + const auto& lua_helptext = tag->text; + auto line = fmt::format("--- @field {} {} {}", key, lua_type, lua_helptext); + result.emplace_back(std::move(line)); + } else { + auto line = fmt::format("--- @field {}", key); + result.emplace_back(std::move(line)); + } + } + return result; +} + SimulationResult Simulation::run() { // Input: SimulationContext ctx{lua_.lua_state()}; + ctx.db = std::make_unique(ctx.lua); ctx.server = make_server(config_.server); - ctx.coordinator = std::make_unique(ctx.lua); - ctx.registrar = std::make_unique(ctx.server->server_registrar(), ctx.coordinator.get()); + ctx.coordinator = std::make_unique(ctx.lua, ctx.db.get()); + ctx.registrar = std::make_unique(ctx.server->server_registrar(), ctx.coordinator.get(), + ctx.db.get()); ctx.commander = std::make_unique(logger()); ctx.sync = SimulationSync(config_.simulation.model_step_width); ctx.config = config_; @@ -1325,6 +1585,13 @@ SimulationResult Simulation::run() { r.elapsed = ctx.progress.elapsed(); r.triggers = ctx.coordinator->history(); r.report = sol::object(cloe::luat_cloe_engine_state(ctx.lua)["report"]); + // Don't create output file data unless the output files are being written + if (ctx.config.engine.output_file_signals) { + r.signals = dump_signals(*ctx.db); + } + if (ctx.config.engine.output_file_signals_autocompletion) { + r.signals_autocompletion = dump_signals_autocompletion(*ctx.db); + } abort_fn_ = nullptr; return r; @@ -1350,6 +1617,8 @@ size_t Simulation::write_output(const SimulationResult& r) const { write_file(r.config.engine.output_file_result, r); write_file(r.config.engine.output_file_config, r.config); write_file(r.config.engine.output_file_triggers, r.triggers); + write_file(r.config.engine.output_file_signals, r.signals); + write_file(r.config.engine.output_file_signals_autocompletion, r.signals_autocompletion); logger()->info("Wrote {} output files.", files_written); return files_written; diff --git a/engine/src/simulation.hpp b/engine/src/simulation.hpp index ce2105861..595304ef9 100644 --- a/engine/src/simulation.hpp +++ b/engine/src/simulation.hpp @@ -46,6 +46,9 @@ struct SimulationResult { SimulationStatistics statistics; cloe::Json triggers; cloe::Json report; + cloe::Json signals; // dump of all signals in DataBroker right before the simulation started + std::vector + signals_autocompletion; // pseudo lua file used for vscode autocompletion boost::optional output_dir; public: diff --git a/engine/src/simulation_context.hpp b/engine/src/simulation_context.hpp index 9bdbc36ff..cb2fd73b2 100644 --- a/engine/src/simulation_context.hpp +++ b/engine/src/simulation_context.hpp @@ -33,6 +33,7 @@ #include // for state_view #include // for Simulator, Controller, Registrar, Vehicle, Duration +#include // for DataBroker #include // for Sync #include // for DEFINE_NIL_EVENT #include // for Accumulator @@ -201,6 +202,7 @@ struct SimulationContext { sol::state_view lua; // Setup + std::unique_ptr db; std::unique_ptr server; std::shared_ptr coordinator; std::shared_ptr registrar; diff --git a/engine/src/stack.hpp b/engine/src/stack.hpp index dc2bdded7..fc73182bb 100644 --- a/engine/src/stack.hpp +++ b/engine/src/stack.hpp @@ -314,6 +314,8 @@ struct EngineConf : public Confable { boost::optional output_file_config{"config.json"}; boost::optional output_file_result{"result.json"}; boost::optional output_file_triggers{"triggers.json"}; + boost::optional output_file_signals{"signals.json"}; + boost::optional output_file_signals_autocompletion; boost::optional output_file_data_stream; bool output_clobber_files{true}; @@ -416,6 +418,8 @@ struct EngineConf : public Confable { {"config", make_schema(&output_file_config, file_proto(), "file to store config in")}, {"result", make_schema(&output_file_result, file_proto(), "file to store simulation result in")}, {"triggers", make_schema(&output_file_triggers, file_proto(), "file to store triggers in")}, + {"signals", make_schema(&output_file_signals, file_proto(), "file to store signals in")}, + {"signals_autocompletion", make_schema(&output_file_signals_autocompletion, file_proto(), "file to store signal autocompletion in")}, {"api_recording", make_schema(&output_file_data_stream, file_proto(), "file to store api data stream")}, }}, }}, diff --git a/engine/src/stack_test.cpp b/engine/src/stack_test.cpp index cee1d60ec..38b3f226f 100644 --- a/engine/src/stack_test.cpp +++ b/engine/src/stack_test.cpp @@ -51,6 +51,7 @@ TEST(cloe_stack, serialization_of_empty_stack) { "files": { "config": "config.json", "result": "result.json", + "signals": "signals.json", "triggers": "triggers.json" } }, @@ -142,6 +143,7 @@ TEST(cloe_stack, serialization_with_logging) { "files": { "config": "config.json", "result": "result.json", + "signals": "signals.json", "triggers": "triggers.json" } }, diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt index 881d7d09d..173e7aacf 100644 --- a/models/CMakeLists.txt +++ b/models/CMakeLists.txt @@ -12,6 +12,7 @@ if(CLOE_FIND_PACKAGES) endif() find_package(Eigen3 REQUIRED QUIET) find_package(Boost COMPONENTS headers REQUIRED QUIET) +find_package(sol2 REQUIRED QUIET) message(STATUS "Building cloe-models library.") file(GLOB cloe-models_PUBLIC_HEADERS "include/**/*.hpp") @@ -21,6 +22,7 @@ add_library(cloe-models src/cloe/component/utility/ego_sensor_canon.cpp src/cloe/component/utility/steering_utils.cpp src/cloe/utility/actuation_state.cpp + src/cloe/utility/lua_types.cpp # For IDE integration ${cloe-models_PUBLIC_HEADERS} @@ -41,6 +43,7 @@ target_link_libraries(cloe-models cloe::runtime Boost::headers Eigen3::Eigen + sol2::sol2 ) # Testing ------------------------------------------------------------- @@ -57,6 +60,7 @@ if(BUILD_TESTING) src/cloe/component/utility/steering_utils_test.cpp src/cloe/utility/actuation_level_test.cpp src/cloe/utility/frustum_culling_test.cpp + src/cloe/utility/lua_types_test.cpp ) set_target_properties(test-models PROPERTIES CXX_STANDARD 17 diff --git a/models/conanfile.py b/models/conanfile.py index f6ab6aa4f..1514aa277 100644 --- a/models/conanfile.py +++ b/models/conanfile.py @@ -44,6 +44,7 @@ def requirements(self): self.requires(f"cloe-runtime/{self.version}@cloe/develop") self.requires("boost/1.74.0") self.requires("eigen/3.4.0") + self.requires("sol2/3.3.1") def build_requirements(self): self.test_requires("gtest/1.14.0") diff --git a/models/include/cloe/component/object.hpp b/models/include/cloe/component/object.hpp index edb94d853..68f4b6474 100644 --- a/models/include/cloe/component/object.hpp +++ b/models/include/cloe/component/object.hpp @@ -24,6 +24,8 @@ #include // for shared_ptr<> #include // for vector<> +#include // for Lua related aspects + #include // for Isometry3d, Vector3d #include // for ENUM_SERIALIZATION #include // for Json @@ -98,6 +100,23 @@ struct Object { {"angular_velocity", o.angular_velocity}, }; } + friend void to_lua(sol::state_view view, Object* /* value */) { + sol::usertype usertype_table = view.new_usertype("Object"); + usertype_table["id"] = &Object::id; + usertype_table["exist_prob"] = &Object::exist_prob; + usertype_table["type"] = &Object::type; + usertype_table["classification"] = &Object::classification; + usertype_table["pose"] = &Object::pose; + usertype_table["dimensions"] = &Object::dimensions; + usertype_table["cog_offset"] = &Object::cog_offset; + usertype_table["velocity"] = &Object::velocity; + usertype_table["acceleration"] = &Object::acceleration; + usertype_table["angular_velocity"] = &Object::angular_velocity; + usertype_table["ego_position"] = +[](const Object &self, const Eigen::Isometry3d &sensorMountPose) { + Eigen::Vector3d pos = sensorMountPose * self.pose * self.cog_offset; + return pos; + }; + } }; /** diff --git a/models/include/cloe/component/wheel.hpp b/models/include/cloe/component/wheel.hpp index b1f910869..f3b2dbd68 100644 --- a/models/include/cloe/component/wheel.hpp +++ b/models/include/cloe/component/wheel.hpp @@ -23,6 +23,8 @@ #include // for to_json +#include // for Lua related aspects + namespace cloe { struct Wheel { @@ -42,6 +44,12 @@ struct Wheel { {"spring_compression", w.spring_compression}, }; } + friend void to_lua(sol::state_view view, Wheel* /* value */) { + sol::usertype usertype_table = view.new_usertype("Wheel"); + usertype_table["rotation"] = &Wheel::rotation; + usertype_table["velocity"] = &Wheel::velocity; + usertype_table["spring_compression"] = &Wheel::spring_compression; + } }; } // namespace cloe diff --git a/models/include/cloe/utility/lua_types.hpp b/models/include/cloe/utility/lua_types.hpp new file mode 100644 index 000000000..b6f695ed4 --- /dev/null +++ b/models/include/cloe/utility/lua_types.hpp @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file cloe/component/ego_sensor.hpp + */ + +#pragma once + +#include + +#include + +#include + +namespace cloe { +namespace utility { + +extern void register_lua_types(cloe::DataBroker& db); + +extern void register_gaspedal_sensor(DataBroker& db, const std::string& vehicle, + std::function gaspedal_getter); +extern void register_wheel_sensor(DataBroker& db, + const std::string& vehicle, + const std::string& wheel_name, + std::function + wheel_getter); + +} // namespace utility +} // namespace cloe diff --git a/models/src/cloe/utility/lua_types.cpp b/models/src/cloe/utility/lua_types.cpp new file mode 100644 index 000000000..bb1fcc761 --- /dev/null +++ b/models/src/cloe/utility/lua_types.cpp @@ -0,0 +1,393 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file cloe/utility/lua_types.cpp + * \see cloe/utility/lua_types.hpp + */ + +#include + +#include + +#include + +namespace sol { +template <> +struct is_automagical : std::false_type {}; +template <> +struct is_automagical : std::false_type {}; +template <> +struct is_automagical : std::false_type {}; +} // namespace sol + +namespace cloe { +namespace utility { + +/** + * Derives matrix type traits from the given typename T + */ +template +struct MatrixTypeTraitsDetail {}; +/** + * Derives matrix type traits from the given typename T + */ +template +struct MatrixTypeTraitsDetail> { + using Scalar = Scalar_; + static constexpr int Rows = Rows_; + static constexpr int Cols = Cols_; +}; + +/** + * Derives the SOL constructors-type from the given typename T + */ +template +struct MatrixCtors {}; +/** + * \brief Derives the SOL constructors-type from the given typename T + * \note Specialization for 2x1 matrices + */ +template +struct MatrixCtors { + using Scalar = typename MatrixTypeTraitsDetail::Scalar; + using Ctors = sol::constructors; +}; +/** + * \brief Derives the SOL constructors-type from the given typename T + * \note Specialization for 3x1 matrices + */ +template +struct MatrixCtors { + using Scalar = typename MatrixTypeTraitsDetail::Scalar; + using Ctors = sol::constructors; +}; +/** + * \brief Derives the SOL constructors-type from the given typename T + * \note Specialization for 4x1 matrices + */ +template +struct MatrixCtors { + using Scalar = typename MatrixTypeTraitsDetail::Scalar; + using Ctors = sol::constructors; +}; + +/** + * Type-Traits for Eigen matrices + */ +template +struct MatrixTypeTraits { + using Scalar = typename MatrixTypeTraitsDetail::Scalar; + using Ctors = typename MatrixCtors::Rows, + MatrixTypeTraitsDetail::Cols>::Ctors; + static constexpr int Rows = MatrixTypeTraitsDetail::Rows; + static constexpr int Cols = MatrixTypeTraitsDetail::Cols; +}; + +/** + * Accessor functions for Eigen matrices + */ +template +struct MatrixAccessor { + template ::Scalar> + static Scalar get(T& matrix) { + return matrix[Row][Col]; + } + template ::Scalar> + static void set(T& matrix, Scalar value) { + matrix[Row][Col] = value; + } +}; + +/** + * Accessor functions for Eigen matrices + */ +template +struct MatrixAccessor { + template ::Scalar> + static Scalar get(T& matrix) { + return matrix[Row]; + } + template ::Scalar> + static void set(T& matrix, Scalar value) { + matrix[Row] = value; + } +}; + +const char* vector_names_xyzw[] = {"x", "y", "z", "w"}; +const char* vector_names_r_phi[] = {"r", "phi", "", ""}; +const char* vector_names_r_theta_phi[] = {"r", "theta", "phi", ""}; +const char* vector_names_rho_eta_phi[] = {"rho", "eta", "phi", ""}; + +const std::vector namespace_eigen = {"eigen"}; + +const std::array namespace_prefix = {"cloe", "types"}; + +/** + * \brief Traverses the global namespace-prefix as well as the given namespace + * \param view Lua state_view + * \param ns_iter Iterator pointing to the beginning of the namspace-array + * \param ns_end Iterator pointing to the end of the namspace-array + * \param table_fn Callback accepting a SOL-table which reflects the given namespace + */ +void traverse_namespace_impl(sol::state_view view, const std::vector& ns, + std::function table_fn) { + const char* name; + sol::table table; + + // traverse the global namespace-prefix + static_assert(namespace_prefix.size() > 0); + { + auto iter = namespace_prefix.cbegin(); + auto end = namespace_prefix.cend(); + name = *iter++; + table = view[name].get_or_create(); + while (iter != end) { + name = *iter++; + table = table[name].get_or_create(); + } + } + // traverse the user-supplied namespace + { + auto iter = ns.cbegin(); + auto end = ns.cend(); + while (iter != end) { + name = *iter++; + table = table[name].get_or_create(); + } + } + + table_fn(table); +} + +/** + * \brief Traverses the given namespace as a preparation for the registration of a type + * \tparam T Type of the class/enum to be registered + * \tparam ns_size Size of the namespace array + * \param db Instance of the DataBroker + * \param ns Array of ASCIIZ strings describing the namespace of the enum-type + * \param table_fn Callback accepting a SOL-table which reflects the given namespace + */ +template +void traverse_namespace(DataBroker& db, const std::vector& ns, + std::function table_fn) { + db.declare_type([&](sol::state_view view) { traverse_namespace_impl(view, ns, table_fn); }); +} + +/** + * \brief Registers a class under a given namespace + * \tparam T Type of the class to be registered + * \tparam Type of the SOL constructor-class + * \tparam ns_size Size of the namespace array + * \param db Instance of the DataBroker + * \param ns Array of ASCIIZ strings describing the namespace of the enum-type + * \param type_name ASCIIZ string describing the name of the class to be registered + */ +template +sol::usertype register_usertype(DataBroker& db, const std::vector& ns, + const char* type_name) { + sol::usertype result; + traverse_namespace( + db, ns, [&](sol::table& table) { result = table.new_usertype(type_name, CtorType()); }); + return result; +} + +/** + * \brief Registers an enum under a given namespace + * \tparam T Type of the enum to be registered + * \tparam Args Types of parameter 'args' + * \param db Instance of the DataBroker + * \param ns std::vector of ASCIIZ strings describing the namespace of the enum-type + * \param type_name ASCIIZ string describing the name of the enum to be registered + * \param args Pairs of ASCIIZ-String of one enum-value & the enum-value itself + */ +template +void register_enum(DataBroker& db, const std::vector& ns, const char* type_name, + Args&&... args) { + traverse_namespace( + db, ns, [&](sol::table& table) { table.new_enum(type_name, std::forward(args)...); }); +} + +/** + * \brief Registers a vector-type under a given namespace + * \tparam T Vector-type to be registered + * \tparam ns_size Size of the namespace array + * \tparam ints Index-sequence into the parameter member_names + * \param ns Array of ASCIIZ strings describing the namespace of the type + * \param type_name ASCIIZ string describing the name of the class to be registered + * \param member_names ASCIIZ string array describing the names of the vector properties + */ +template +void register_vector(DataBroker& db, const std::vector& ns, const char* type_name, + const char* member_names[], std::index_sequence) { + sol::usertype usertype = + register_usertype::Ctors>(db, ns, type_name); + + // Register properties x,y,z, w + ((usertype[member_names[ints]] = sol::property(&MatrixAccessor::template get, + &MatrixAccessor::template set)), + ...); + // Register operators + usertype[sol::meta_function::unary_minus] = [](const T& rhs) -> T { return -rhs; }; + usertype[sol::meta_function::addition] = [](const T& lhs, const T& rhs) -> T { + return lhs + rhs; + }; + usertype[sol::meta_function::subtraction] = [](const T& lhs, const T& rhs) -> T { + return lhs - rhs; + }; + + usertype[sol::meta_function::equal_to] = [](const T& lhs, const T& rhs) -> bool { + return lhs == rhs; + }; + + // Register methods + usertype["norm"] = [](const T& that) -> + typename MatrixTypeTraits::Scalar { return that.norm(); }; + usertype["dot"] = [](const T& that, const T& arg) -> + typename MatrixTypeTraits::Scalar { return that.dot(arg); }; + + // Vector3x can do a cross-product + if constexpr (3 == sizeof...(ints)) { + usertype["cross"] = [](const T& that, const T& arg) -> T { return that.cross(arg); }; + } +} + +template +void register_vector(DataBroker& db, const std::vector& ns, const char* class_name, + const char* member_names[]) { + register_vector(db, ns, class_name, member_names, + std::make_index_sequence::Rows>{}); +} + +std::vector namespace_cloe_object = {"cloe", "Object"}; + +void register_gaspedal_sensor(DataBroker& db, const std::string& vehicle, + std::function gaspedal_getter) { + { + using type = double; + auto signal = db.declare(fmt::format("vehicles.{}.sensor.gaspedal.position", vehicle)); + signal->set_getter(std::move(gaspedal_getter)); + auto documentation = fmt::format( + "Normalized gas pedal position for the '{}' vehicle

" + "Range [min-max]: [0-1]", + vehicle); + signal->add( + cloe::LuaAutocompletionTag::LuaDatatype::Number, + cloe::LuaAutocompletionTag::PhysicalQuantity::Dimensionless, + documentation); + signal->add(documentation); + } +} +void register_wheel_sensor(DataBroker& db, + const std::string& vehicle, + const std::string& wheel_name, + std::function + wheel_getter) { + { + using type = cloe::Wheel; + auto signal = + db.declare(fmt::format("vehicles.{}.sensor.wheels.{}", vehicle, wheel_name)); + signal->set_getter([wheel_getter]() -> const type& { return wheel_getter(); }); + auto documentation = fmt::format( + "Wheel sensor for the front-left wheel of the '{}' vehicle

" + "rotation: Rotational angle of wheel around y-axis in [rad]
" + "velocity: Compression of the spring in [m]
" + "spring_compression: Compression of the spring in [m]", + vehicle); + signal->add( + cloe::LuaAutocompletionTag::LuaDatatype::Class, + cloe::LuaAutocompletionTag::PhysicalQuantity::Dimensionless, + documentation); + signal->add(documentation); + } + { + using type = decltype(cloe::Wheel::rotation); + auto signal = + db.declare(fmt::format("vehicles.{}.sensor.wheels.{}.rotation", vehicle, wheel_name)); + signal->set_getter([wheel_getter]() -> const type& { return wheel_getter().rotation; }); + auto documentation = + fmt::format("Sensor for the rotation around y-axis of the {} wheel of the '{}' vehicle", + wheel_name, vehicle); + signal->add(cloe::LuaAutocompletionTag::LuaDatatype::Number, + cloe::LuaAutocompletionTag::PhysicalQuantity::Radian, + documentation); + signal->add(documentation); + } + { + using type = decltype(cloe::Wheel::velocity); + auto signal = + db.declare(fmt::format("vehicles.{}.sensor.wheels.{}.velocity", vehicle, wheel_name)); + signal->set_getter([wheel_getter]() -> const type& { return wheel_getter().velocity; }); + auto documentation = + fmt::format("Sensor for the translative velocity of the {} wheel of the '{}' vehicle", + wheel_name, vehicle); + signal->add(cloe::LuaAutocompletionTag::LuaDatatype::Number, + cloe::LuaAutocompletionTag::PhysicalQuantity::Velocity, + documentation); + signal->add(documentation); + } + { + using type = decltype(cloe::Wheel::spring_compression); + auto signal = db.declare( + fmt::format("vehicles.{}.sensor.wheels.{}.spring_compression", vehicle, wheel_name)); + signal->set_getter( + [wheel_getter]() -> const type& { return wheel_getter().spring_compression; }); + auto documentation = + fmt::format("Wheel sensor for spring compression of the {} wheel of the '{}' vehicle", + wheel_name, vehicle); + signal->add(cloe::LuaAutocompletionTag::LuaDatatype::Number, + cloe::LuaAutocompletionTag::PhysicalQuantity::Radian, + documentation); + signal->add(documentation); + } +} + +void register_lua_types(DataBroker& db) { + register_vector(db, namespace_eigen, "Vector2i", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector3i", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector4i", vector_names_xyzw); + + register_vector(db, namespace_eigen, "Vector2f", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector3f", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector4f", vector_names_xyzw); + + register_vector(db, namespace_eigen, "Vector2d", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector3d", vector_names_xyzw); + register_vector(db, namespace_eigen, "Vector4d", vector_names_xyzw); + + // clang-format off + register_enum<::cloe::Object::Type>( + db, namespace_cloe_object, "Type", + "Unknown", ::cloe::Object::Type::Unknown, + "Static", ::cloe::Object::Type::Static, + "Dynamic", ::cloe::Object::Type::Dynamic + ); + register_enum<::cloe::Object::Class>( + db, namespace_cloe_object, "Class", + "Unknown", ::cloe::Object::Class::Unknown, + "Pedestrian", ::cloe::Object::Class::Pedestrian, + "Bike", ::cloe::Object::Class::Bike, + "Motorbike", ::cloe::Object::Class::Motorbike, + "Car", ::cloe::Object::Class::Car, + "Truck", ::cloe::Object::Class::Truck, + "Trailer", ::cloe::Object::Class::Trailer + ); + // clang-format on +} + +} // namespace utility +} // namespace cloe diff --git a/models/src/cloe/utility/lua_types_test.cpp b/models/src/cloe/utility/lua_types_test.cpp new file mode 100644 index 000000000..e766d326e --- /dev/null +++ b/models/src/cloe/utility/lua_types_test.cpp @@ -0,0 +1,118 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file cloe/utility/lua_types_test.cpp + */ + +#include + +#include + +#include // for Eigen + +#include // for cloe::Object + +using DataBroker = cloe::DataBroker; + +TEST(lua_types_test, object) { + // Test Scenario: positive-test + // Test Case Description: Implement a vector3d signal and manipulate a member from Lua + // Test Steps: 1) Implement a signal + // 2) Stimulate the signal from Lua + // Prerequisite: - + // Test Data: - + // Expected Result: I) The value of the member changed + sol::state state; + sol::state_view view(state); + DataBroker db{view}; + + // Register all types + cloe::utility::register_lua_types(db); + + // 1) Implement a signal + auto gamma = db.implement("gamma"); + + // bind signals + db.bind_signal("gamma"); + db.bind("signals"); + // 2) Manipulate a member from Lua + const auto &code = R"( + local gamma = signals.gamma + gamma.type = cloe.types.cloe.Object.Type.Static; + gamma.classification = cloe.types.cloe.Object.Class.Pedestrian + signals.gamma = gamma + )"; + // run lua + state.open_libraries(sol::lib::base, sol::lib::package); + state.script(code); + // verify I + EXPECT_EQ(gamma->type, cloe::Object::Type::Static); + EXPECT_EQ(gamma->classification, cloe::Object::Class::Pedestrian); +} + +TEST(lua_types_test, vector3d) { + // Test Scenario: positive-test + // Test Case Description: Implement a vector3d signal and manipulate a member from Lua + // Test Steps: 1) Implement a signal + // 2) Stimulate the signal from Lua + // Prerequisite: - + // Test Data: - + // Expected Result: I) The value of the member changed + // II) The value-changed event was received + sol::state state; + sol::state_view view(state); + DataBroker db{view}; + + // Register all types + cloe::utility::register_lua_types(db); + + // 1) Implement a signal + auto gamma = db.implement("gamma"); + auto five = db.implement("five"); + + // bind signals + db.bind_signal("gamma"); + db.bind_signal("five"); + db.bind("signals"); + // 2) Manipulate a member from Lua + const auto &code = R"( + -- use default-constructor + local gamma = cloe.types.eigen.Vector3d.new() + gamma.x = -1 + gamma.y = 1.154431 + gamma.z = 3.1415926 + signals.gamma = gamma + + -- use value-constructor + local vec = cloe.types.eigen.Vector2i.new(3, 4) + + -- use copy-constructor + local vec2 = cloe.types.eigen.Vector2i.new(vec) + + -- use member-method + signals.five = vec2:norm() + )"; + // run lua + state.open_libraries(sol::lib::base, sol::lib::package); + state.script(code); + // verify I + EXPECT_EQ(gamma->operator[](0), -1); + EXPECT_EQ(gamma->operator[](1), 1.154431); + EXPECT_EQ(gamma->operator[](2), 3.1415926); + EXPECT_EQ(five, 5); +} diff --git a/plugins/basic/src/basic.cpp b/plugins/basic/src/basic.cpp index 27baf6648..d208218a9 100644 --- a/plugins/basic/src/basic.cpp +++ b/plugins/basic/src/basic.cpp @@ -38,6 +38,7 @@ #include // for ObjectSensor #include // for EgoSensor, EgoSensorCanon #include // for Controller, Json, etc. +#include // for DataBroker #include // for ToJson, FromConf #include // for CloeComponent #include // for EXPORT_CLOE_PLUGIN @@ -396,6 +397,47 @@ class BasicController : public Controller { } void enroll(Registrar& r) override { + auto& db = r.data_broker(); + if (this->veh_) { + auto& vehicle = this->veh_->name(); + { + std::string name1 = fmt::format("vehicles.{}.{}.acc", vehicle, name()); + auto acc_signal = db.declare(name1); + acc_signal->set_getter( + [this]() -> const cloe::controller::basic::AccConfiguration& { + return this->acc_.config; + }); + acc_signal->set_setter( + [this](const cloe::controller::basic::AccConfiguration& value) { + this->acc_.config = value; + }); + } + { + std::string name1 = fmt::format("vehicles.{}.{}.aeb", vehicle, name()); + auto aeb_signal = db.declare(name1); + aeb_signal->set_getter( + [this]() -> const cloe::controller::basic::AebConfiguration& { + return this->aeb_.config; + }); + aeb_signal->set_setter( + [this](const cloe::controller::basic::AebConfiguration& value) { + this->aeb_.config = value; + }); + } + { + std::string name1 = fmt::format("vehicles.{}.{}.lka", vehicle, name()); + auto lka_signal = db.declare(name1); + lka_signal->set_getter( + [this]() -> const cloe::controller::basic::LkaConfiguration& { + return this->lka_.config; + }); + lka_signal->set_setter( + [this](const cloe::controller::basic::LkaConfiguration& value) { + this->lka_.config = value; + }); + } + } + auto lua = r.register_lua_table(); { diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 771837e16..9e46c5708 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -46,6 +46,7 @@ add_library(cloe-runtime SHARED src/cloe/utility/std_extensions.cpp src/cloe/utility/uid_tracker.cpp src/cloe/utility/xdg.cpp + src/cloe/data_broker.cpp # For IDE integration ${cloe-runtime_PUBLIC_HEADERS} @@ -96,11 +97,21 @@ if(BUILD_TESTING) src/cloe/version_test.cpp src/cloe/utility/statistics_test.cpp src/cloe/utility/uid_tracker_test.cpp + src/cloe/data_broker_test.cpp ) set_target_properties(test-cloe PROPERTIES CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON ) + target_compile_options(test-cloe + PUBLIC + -g + -O0 + ) + target_compile_definitions(test-cloe + PUBLIC + SOL_ALL_SAFETIES_ON=1 + ) target_link_libraries(test-cloe PRIVATE GTest::gtest diff --git a/runtime/include/cloe/cloe_fwd.hpp b/runtime/include/cloe/cloe_fwd.hpp index b04e1131f..fa054f3ec 100644 --- a/runtime/include/cloe/cloe_fwd.hpp +++ b/runtime/include/cloe/cloe_fwd.hpp @@ -43,6 +43,13 @@ class ConcludedError; using Logger = std::shared_ptr; using LogLevel = spdlog::level::level_enum; +// from data_broker.hpp +template +class BasicContainer; +class Signal; +using SignalPtr = std::shared_ptr; +class DataBroker; + // from entity.hpp class Entity; diff --git a/runtime/include/cloe/data_broker.hpp b/runtime/include/cloe/data_broker.hpp new file mode 100644 index 000000000..293cdf8d8 --- /dev/null +++ b/runtime/include/cloe/data_broker.hpp @@ -0,0 +1,2063 @@ +/* + * Copyright 2023 Robert Bosch GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +/** + * \file cloe/data_broker.hpp + * + * Classes: + * DataBroker: + * A central registry for (type erased) signals. + * + * Signal: + * A type-erased abstraction of a signal. + * + * BasicContainer + * An optional container storing the value of a signal and abstracting the interaction with Signal. + * + * Background: + * Real world simulations utilize manifolds of signals (variables) and coresponding types. + * Still the very basic problem around a single signal (variable) boils down to CRUD + * - declaring a signal (Create) + * - reading its value / receiving value-changed events (Read) + * - writing its value / triggering value-changed events (Update) + * - NOT NEEDED: undeclaring a signal (Delete) + * + * Concept: + * Abstract CRU(D) operations by registering signals in a uniform way by their name. + */ + +#pragma once +#ifndef CLOE_DATA_BROKER_HPP_ +#define CLOE_DATA_BROKER_HPP_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace cloe { + +namespace databroker { + +namespace detail { + +/** + * Detects the presence of the to_lua function (based on ADL) + */ +template +struct has_to_lua : std::false_type {}; +/** + * Detects the presence of the to_lua function (based on ADL) + */ +template +struct has_to_lua< + T, std::void_t(), std::declval()))>> + : std::true_type {}; +/** + * Detects the presence of the to_lua function (based on ADL) + */ +template +constexpr bool has_to_lua_v = has_to_lua::value; + +/** + * Invokes to_lua procedure, if detecting its presence + */ +template +void to_lua(sol::state_view lua) { + if constexpr (has_to_lua_v) { + to_lua(lua, static_cast(nullptr)); + } else { + // nop + } +} + +} // namespace detail + +/** + * Maps a predicate to type T or type int. + * \returns T, if condition is true + * int, otherwise + */ +template +struct type_t_or_int_if { + using type = T; +}; + +template +struct type_t_or_int_if { + using type = int; +}; + +/** + * Predicate which determines whether a type is incompatible with the data + * broker. + */ +template +constexpr bool is_incompatible_type_v = std::is_void_v || std::is_reference_v; + +/** + * Determines a datatype which is compatible with the data broker, derived from + * the template type. + * + * \returns T, if T is compatible with the data broker; + * int, otherwise + * \note Purpose is to suppress irrelevant compiler errors, in case of + * incompatible data types + */ +template +using compatible_base_t = typename type_t_or_int_if>::type; + +/** + * Argument-/Return-Type of signal related functions + */ +template +using signal_type_cref_t = databroker::compatible_base_t const&; + +/** + * Type of event function, which is called when the value of a signal changed + */ +template +using on_value_changed_callback_t = std::function)>; + +/** + * Abstract event implementation. + * + * \tparam TArgs Arguments of the event handler function + * \note: Design-Goals: + * - Design-#1: Unsubscribing from an event is not intended + */ +template +class Event { + public: + using EventHandler = std::function; + + private: + std::vector eventhandlers_{}; + + public: + /** + * Add an event handler to this event. + */ + void add(EventHandler handler) { eventhandlers_.emplace_back(std::move(handler)); } + + /** + * Return the number of event handlers subscribed to this event. + */ + [[nodiscard]] std::size_t count() const { return eventhandlers_.size(); } + + /** + * Raise this event. + * \param args Parameters of this event + */ + void raise(TArgs&&... args) const { + for (const auto& eventhandler : eventhandlers_) { + try { + eventhandler(std::forward(args)...); + } catch (...) { + throw; + } + } + } +}; + +} // namespace databroker + +// Forward declarations: +template +class BasicContainer; +class Signal; +class DataBroker; + +/** + * Assert that the type-argument is compatible with the data broker. + */ +template +constexpr void assert_static_type() { + static_assert(!static_cast(databroker::is_incompatible_type_v), + "Incompatible-Datatype-Error.\n" + "\n" + "Please find the offending LOC in above line 'require from here'.\n" + "\n" + "Explanation/Reasoning:\n" + "- references & void are fundamentally incompatible"); +} + +using SignalPtr = std::shared_ptr; + +/** + * Function which integrates a specific datum into the Lua-VM + */ +using lua_signal_adapter_t = + std::function; + +/** + * Function which declares a specific datatype to the Lua-VM + */ +using lua_signal_declarator_t = std::function; + +template +using Container = BasicContainer>; + +template +class BasicContainer { + public: + using value_type = databroker::compatible_base_t; + + private: + /** + * Access-token for regulating API access (public -> private) + */ + struct access_token { + explicit access_token(int /*unused*/){}; + }; + + value_type value_{}; + databroker::on_value_changed_callback_t on_value_changed_{}; + Signal* signal_{}; + + public: + BasicContainer() = default; + BasicContainer(Signal* signal, + databroker::on_value_changed_callback_t + on_value_changed, + access_token /*unused*/) + : value_(), on_value_changed_(std::move(on_value_changed)), signal_(signal) { + update_accessor_functions(this); + } + BasicContainer(const BasicContainer&) = delete; + BasicContainer(BasicContainer&& source) { *this = std::move(source); } + + ~BasicContainer() { update_accessor_functions(nullptr); } + BasicContainer& operator=(const BasicContainer&) = delete; + BasicContainer& operator=(BasicContainer&& rhs) { + value_ = std::move(rhs.value_); + on_value_changed_ = std::move(rhs.on_value_changed_); + signal_ = std::move(rhs.signal_); + update_accessor_functions(this); + + rhs.signal_ = nullptr; + + return *this; + } + BasicContainer& operator=(databroker::signal_type_cref_t value) { + value_ = value; + if (on_value_changed_) { + on_value_changed_(value_); + } + return *this; + } + + const value_type& value() const { return value_; } + value_type& value() { return value_; } + void set_value(databroker::signal_type_cref_t value) { *this = value; } + + [[nodiscard]] bool has_subscriber() const; + [[nodiscard]] std::size_t subscriber_count() const; + + // mimic std::optional + constexpr const value_type* operator->() const noexcept { return &value_; } + constexpr value_type* operator->() noexcept { return &value_; } + constexpr const value_type& operator*() const noexcept { return value_; } + constexpr value_type& operator*() noexcept { return value_; } + + private: + void update_accessor_functions(BasicContainer* container); + + friend class Signal; +}; + +// Compare two BasicContainer + +template +constexpr bool operator==(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs == *rhs; +} +template +constexpr bool operator!=(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs != *rhs; +} +template +constexpr bool operator<(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs < *rhs; +} +template +constexpr bool operator<=(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs <= *rhs; +} +template +constexpr bool operator>(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs > *rhs; +} +template +constexpr bool operator>=(const BasicContainer& lhs, const BasicContainer& rhs) { + return *lhs >= *rhs; +} + +// Compare BasicContainer with a value + +template +constexpr bool operator==(const BasicContainer& lhs, const U& rhs) { + return *lhs == rhs; +} +template +constexpr bool operator==(const T& lhs, const BasicContainer& rhs) { + return lhs == *rhs; +} +template +constexpr bool operator!=(const BasicContainer& lhs, const U& rhs) { + return *lhs != rhs; +} +template +constexpr bool operator!=(const T& lhs, const BasicContainer& rhs) { + return lhs != *rhs; +} +template +constexpr bool operator<(const BasicContainer& lhs, const U& rhs) { + return *lhs < rhs; +} +template +constexpr bool operator<(const T& lhs, const BasicContainer& rhs) { + return lhs < *rhs; +} +template +constexpr bool operator<=(const BasicContainer& lhs, const U& rhs) { + return *lhs <= rhs; +} +template +constexpr bool operator<=(const T& lhs, const BasicContainer& rhs) { + return lhs <= *rhs; +} +template +constexpr bool operator>(const BasicContainer& lhs, const U& rhs) { + return *lhs > rhs; +} +template +constexpr bool operator>(const T& lhs, const BasicContainer& rhs) { + return lhs > *rhs; +} +template +constexpr bool operator>=(const BasicContainer& lhs, const U& rhs) { + return *lhs >= rhs; +} +template +constexpr bool operator>=(const T& lhs, const BasicContainer& rhs) { + return lhs >= *rhs; +} + +/** + * MetaInformation collects abstract metainformation + * + * \note: Design-Goals: + * - Design-#1: Key-Value (cardinality: 0-1:1). The key defines the value-type. + * - Design-#2: Type-erasing techniques shall not eradicate type-safety nor put additional validation steps onto the users. + * \note: Implementation-Notes: + * - Implementation-#1: + * Implementations like "all-is-byte-arrays" or "JSON" were considered and disregarded. + * E.g. JSON is a serialization format. Using JSON-Schema Design-#2 would be covered. + * This would imply a) a dependency on multiple levels & b) significant runtime efforts. + * Shooting sparrows with canons. Pure C++ can do the job in <50 LOC + some for porcellain. + */ +class MetaInformation { + private: + using metainformation_map_t = std::unordered_map; + metainformation_map_t metainformations_; + + public: + /** + * Tag which identifies the metainformation and carries the type information of the actual metainformation + */ + template + struct Tag { + public: + using tag_type = T; + }; + + public: + MetaInformation() = default; + virtual ~MetaInformation() = default; + + private: + template + constexpr void assert_static_type() { + // prevent usage of references + static_assert(std::is_reference_v == false, + "References are unsupported."); + } + + public: + template + /** + * Removes an metainformation + * \tparam Tag of the metainformation to be removed + */ + void remove() { + auto tindex = std::type_index(typeid(T)); + auto iter = metainformations_.find(tindex); + if (iter != metainformations_.end()) { + metainformations_.erase(iter); + } + } + /** + * Adds a metainformation + * \tparam T Type of the metainformation-tag + * \param metainformation_any Actual metainformation to be added + */ + template + void add_any(std::any metainformation_any) { + auto tindex = std::type_index(typeid(T)); + auto iter = metainformations_.find(tindex); + if (iter != metainformations_.end()) { + } + metainformations_[tindex] = std::move(metainformation_any); + } + /** + * Returns a metainformation + * \tparam T Type of the metainformation-tag + * \returns std:any* if the metainformation is present, nullptr otherwise + */ + template + const std::any* get_any() const { + auto tindex = std::type_index(typeid(T)); + auto iter = metainformations_.find(tindex); + if (iter != metainformations_.end()) { + const std::any& metainformation_any = iter->second; + return &metainformation_any; + } else { + return nullptr; + } + } + /** + * Returns a metainformation + * \tparam T Type of the metainformation-tag + * \returns Annotation of type T::tag_type* if the metainformation is present, nullptr otherwise + */ + template + std::enable_if_t, const typename T::tag_type>* get() const { + const std::any* metainformation_any = get_any(); + return (metainformation_any != nullptr) + ? std::any_cast(metainformation_any) + : nullptr; + } + /** + * Returns a metainformation + * \tparam T Type of the metainformation-tag + * \returns true if the metainformation is present, false otherwise + */ + template + std::enable_if_t, bool> get() const { + const std::any* metainformation_any = get_any(); + return (metainformation_any != nullptr); + } + + /** + * Adds a metainformation + * \tparam T Type of the metainformation-tag + * \param metainformation Actual metainformation to be added + * \note This overload is enabled only when the effective tag_type is move constructible + */ + template + // clang-format off + std::enable_if_t< + !std::is_void_v + && !std::is_base_of_v< Tag< T >, T > + && std::is_move_constructible_v + > + // clang-format on + add(typename T::tag_type metainformation) { + assert_static_type(); + + std::any metainformation_any = std::move(metainformation); + add_any(std::move(metainformation_any)); + } + /** + * Adds a metainformation + * \tparam T Type of the metainformation-tag + * \param metainformation Actual metainformation to be added + * \note This overload is enabled only when the effective tag_type is copy constructible + */ + template + // clang-format off + std::enable_if_t< + !std::is_void_v + && !std::is_base_of_v< Tag< T >, T > + && std::is_copy_constructible_v + && !std::is_move_constructible_v + > + // clang-format on + add(const typename T::tag_type& metainformation) { + assert_static_type(); + + std::any metainformation_any = metainformation; + add_any(std::move(metainformation_any)); + } + /** + * Adds a metainformation + * \tparam T Type of the metainformation-tag + * \note This overload is enabled only when the effective tag_type is void + */ + template + // clang-format off + std::enable_if_t< + std::is_void_v + > + // clang-format on + add() { + std::any metainformation_any; + add_any(std::move(metainformation_any)); + } + /** + * Adds a metainformation + * \tparam T Type of the metainformation-tag + * \param metainformation Actual metainformation to be added + * \note This overload is enabled only when the tag inherits Tag<> and is the effective tag_type + * + * Usage Note: + * + * ``` + * struct my_tag : cloe::MetaInformation::Tag { + * ... + * }; + * ... + * my_tag tag; + * ... + * add(std::move(tag)); + * ``` + */ + template + // clang-format off + std::enable_if_t< + std::is_base_of_v< Tag< T >, T > + && std::is_same_v + > + // clang-format on + add(T metainformation) { + assert_static_type(); + + std::any metainformation_any = std::move(metainformation); + add_any(std::move(metainformation_any)); + } + /** + * Adds a metainformation constructed from the supplied parameters + * \tparam T Type of the metainformation-tag + * \tparam TArgs... Type of the metainformation c'tor arguments + * \param args Arguments for the c'tor of the metainformation + * \note This overload is enabled only when the tag inherits Tag<> and is the effective tag_type + * + * Usage Note: + * + * ``` + * add(arg1, arg2, ...); + * ... + * }; + * ``` + */ + template + // clang-format off + std::enable_if_t< + std::is_base_of_v< Tag< T >, T > + && std::is_same_v + > + // clang-format on + add(TArgs... args) { + assert_static_type(); + T metainformation(std::forward(args)...); + std::any metainformation_any = std::move(metainformation); + add_any(std::move(metainformation_any)); + } +}; + +struct SignalDocumentation : MetaInformation::Tag { + /** + * Documentation text + * \note Use
to achieve a linebreak + */ + std::string text; + + SignalDocumentation(std::string text_) : text{std::move(text_)} {} + + friend const std::string& to_string(const SignalDocumentation& doc) { return doc.text; } +}; + +/** + * Signal-Metainformation for generation of Lua documentation + */ +struct LuaAutocompletionTag : MetaInformation::Tag { +/** + * X-Macro: enum definition & enum-to-string conversion + */ +#define LUADATATYPE_LIST \ + X(Class, 0) \ + X(Number, 1) \ + X(String, 2) + + enum class LuaDatatype { +#define X(name, value) name = value, + LUADATATYPE_LIST +#undef X + }; + + friend std::string to_string(LuaDatatype type) { + switch (type) { +#define X(name, value) \ + case LuaDatatype::name: \ + return #name; + LUADATATYPE_LIST +#undef X + default: + return {}; + } + } +#undef LUADATATYPE_LIST + +/** + * X-Macro: enum definition & enum-to-string conversion + */ +#define PHYSICALQUANTITIES_LIST \ + X(Dimensionless, "[]") \ + X(Radian, "[rad]") \ + X(Length, "[m]") \ + X(Time, "[s]") \ + X(Mass, "[kg]") \ + X(Temperature, "[K]") \ + X(ElectricCurrent, "[A]") \ + X(Velocity, "[m/s]") \ + X(Acceleration, "[m/s^2]") \ + X(Jerk, "[m/s^3]") \ + X(Jounce, "[m/s^4]") \ + X(Crackle, "[m/s^5]") + + enum class PhysicalQuantity { +#define X(name, value) name, + PHYSICALQUANTITIES_LIST +#undef X + }; + + friend std::string to_string(PhysicalQuantity type) { + switch (type) { +#define X(name, value) \ + case PhysicalQuantity::name: \ + return #value; + PHYSICALQUANTITIES_LIST +#undef X + default: + return {}; + } + } +#undef PHYSICALQUANTITIES_LIST + + /** + * Lua datatype of the signal + */ + LuaDatatype datatype; + /** + * Lua datatype of the signal + */ + PhysicalQuantity unit; + /** + * Documentation text + * \note Use
to achieve a linebreak + */ + std::string text; + + LuaAutocompletionTag(LuaDatatype datatype_, PhysicalQuantity unit_, std::string text_) + : datatype{std::move(datatype_)}, unit{std::move(unit_)}, text{std::move(text_)} {} +}; + +/** + * Signal represents the properties of a signal at runtime + * + * \note: Design-Goals: + * - Design-#1: The class shall expose a uniform interface (via type-erasure) + * - Design-#2: The class shall be constructible via external helpers + * - Design-#3: CppCoreGuidelines CP-1 does not apply. See also: https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rconc-multi + * - Design-#4: Templates shall be instantiated explicitly + * \note: Implementation-Notes: + * - Implementation-#1: Type specific aspects are implemented in templates (Design-#1) + * - Implementation-#2: Objects are created via a factory-method (to prevent instances of incomplete initialization) + * - Implementation-#3: An access-token is used to have a public c'tor which in fact shall be inaccesible (Design-#2). + */ +class Signal { + private: + /** + * Access-token for regulating API access (public -> private). + */ + struct access_token { + explicit access_token(int /*unused*/){}; + }; + + public: + /** + * Getter-function types + * + * Note: + * When implementing this function with lamdas (e.g. set_getter<>() ) + * take care to explicitly define the return-type. + * + * Example: + * []() -> const int& { return value; } + * + * Undefined behaviour: + * []() { return value; } + * ^^^ No return type specified + */ + template + using typed_get_value_function_t = std::function()>; + using type_erased_get_value_function_t = std::any; + + /** + * Setter-function types + */ + template + using typed_set_value_function_t = std::function)>; + using type_erased_set_value_function_t = std::any; + /** + * Event types + */ + template + using typed_value_changed_event_t = databroker::Event>; + using type_erased_value_changed_event_t = std::any; + /** + * Event trigger-function types + */ + template + using typed_on_value_change_event_function_t = + std::function)>; + using type_erased_on_value_change_event_function_t = std::any; + + private: + /// Name(s) of the signal + std::vector names_{}; + /// std::type_info of the signal + const std::type_info* type_{nullptr}; + /// getter-function + type_erased_get_value_function_t get_value_{}; + /// setter-function + type_erased_set_value_function_t set_value_{}; + /// Event which gets raised when the signal value changes + type_erased_value_changed_event_t value_changed_event_{}; + /// Triggers the value-changed event + type_erased_on_value_change_event_function_t on_value_changed_{}; + /// std::function returning the count of event subscribers + std::function subscriber_count_{}; + /// metadata accompanying the signal + MetaInformation metainformations_; + + /** + * Private default c'tor + * \note: Implementation-#2: The class shall be created only via a factory-method + */ + Signal() = default; + + public: + /** + * Public c'tor, accessible only via private access-token + * \note: Design-#1: The class shall be constructible via external helpers + */ + explicit Signal(access_token /*unused*/) : Signal() {} + Signal(const Signal&) = delete; + Signal(Signal&&) = default; + virtual ~Signal() = default; + Signal& operator=(const Signal&) = delete; + Signal& operator=(Signal&&) = default; + + /** + * Return the type info for the signal. + */ + constexpr const std::type_info* type() const { return type_; } + + private: + /** + * Validate that the stored signal type matches the template type. + * + * \tparam T Expected type of the signal + */ + template + constexpr void assert_dynamic_type() const { + const std::type_info* static_type = &typeid(T); + const std::type_info* dynamic_type = type(); + if ((dynamic_type == nullptr) || (!(*dynamic_type == *static_type))) { + throw std::logic_error( + fmt::format("mismatch between dynamic-/actual-type and static-/requested-type; " + "signal type: {}, requested type: {}", + dynamic_type != nullptr ? dynamic_type->name() : "", static_type->name())); + } + } + + public: + /** + * Return the getter function of the signal. + * + * \tparam T Type of the signal + * \return Getter function of the signal + * + * Example: + * ``` + * Signal s = ...; + * const typed_get_value_function_t* f = s.getter(); + * int v = (*f)(); + * ``` + */ + template + const typed_get_value_function_t* getter() const { + assert_static_type(); + assert_dynamic_type(); + + const typed_get_value_function_t* get_value_fn = + std::any_cast>(&get_value_); + return get_value_fn; + } + + /** + * Set the getter function of the signal. + * + * \tparam T Type of the signal + * \param get_value Getter function of the signal + * + * Example: + * ``` + * Signal s = ...; + * s.set_getter([&]() -> const type & { return value; }); + * ``` + * Usage Note: When using lambdas, the explicit return-type definition is important! + * + * Undefined behaviour: + * ``` + * Signal s = ...; + * s.set_getter([&]() { return value; }); + * ^^^ No return type specified, type conversion rules apply + * ``` + */ + template + void set_getter(typed_get_value_function_t get_value_fn) { + assert_static_type(); + assert_dynamic_type(); + + get_value_ = std::move(get_value_fn); + } + + /** + * Return the current value of the signal. + * + * \tparam T Type of the signal + * \return databroker::signal_type_cref_t, Current value of the signal + * \note databroker::compatible_base_t == T, if the method compiles + */ + template + databroker::signal_type_cref_t value() const { + assert_static_type(); + assert_dynamic_type(); + using compatible_type = databroker::compatible_base_t; + + const typed_get_value_function_t* getter_fn = getter(); + if (getter_fn && (*getter_fn)) { + databroker::signal_type_cref_t value = getter_fn->operator()(); + return value; + } + throw std::logic_error( + fmt::format("unable to get value for signal without getter-function: {}", names_.front())); + } + + /** + * Return the getter function of the signal. + * + * \tparam T Type of the signal + * \return const typed_set_value_function_t*, Getter function of the signal + */ + template + const typed_set_value_function_t* setter() const { + assert_static_type(); + assert_dynamic_type(); + + const typed_set_value_function_t* set_value_fn = + std::any_cast>(&set_value_); + return set_value_fn; + } + + /** + * Set the setter function of the signal. + * + * \tparam T Type of the signal + * \param set_value Getter function of the signal + */ + template + void set_setter(typed_set_value_function_t set_value_fn) { + assert_static_type(); + assert_dynamic_type(); + + set_value_ = std::move(set_value_fn); + } + + /** + * Set the value of the signal. + * + * \tparam T Type of the signal + * \param value Value of the signal + */ + template + void set_value(databroker::signal_type_cref_t value) const { + assert_static_type(); + assert_dynamic_type(); + using compatible_type = databroker::compatible_base_t; + + const typed_set_value_function_t* setter_fn = setter(); + if (setter_fn && *setter_fn) { + setter_fn->operator()(value); + return; + } + throw std::logic_error( + fmt::format("unable to set value for signal without setter-function: {}", names_.front())); + } + + /** + * Return the trigger function for the value_changed event. + * + * \tparam T Type of the signal + * \return Trigger function for raising the value_changed event + */ + template + const typed_on_value_change_event_function_t& trigger() const { + assert_static_type(); + assert_dynamic_type(); + + const typed_on_value_change_event_function_t* trigger_fn = + std::any_cast>(&on_value_changed_); + assert(trigger_fn); + return *trigger_fn; + } + + /** + * Tags a signal with metadata + * + * \tparam T Type of the tag + * \param metadata Metadata used to tag the signal + * \note This is the overload for non-void tags (T2 != void) + */ + template + std::enable_if_t> add(typename T::tag_type metadata) { + static_assert(std::is_reference_v == false); + metainformations_.add(metadata); + } + /** + * Tags a signal with metadata constructed from parameters + * + * \tparam T Type of the tag + * \tparam TArgs Type of the metadata c'tor parameters + * \param args Metadata c'tor arguments + * \note This is the overload for non-void tags (T2 != void) + */ + template + std::enable_if_t> add(TArgs&&... args) { + static_assert(std::is_reference_v == false); + metainformations_.add(std::forward(args)...); + } + + /** + * Tags a signal with a void tag + * + * \tparam T Type of the tag + */ + template + // clang-format off + std::enable_if_t< + std::is_void_v + > + // clang-format on + add() { + metainformations_.add(); + } + /** + * Get a tag of the signal + * + * \tparam T Type of the tag + * \returns const typename T::tag_type* pointing to the tag-value (or nullptr), if typename T::tag_type != void + * \returns bool expressing the presence of the tag, if typename T::tag_type == void + */ + template + auto metadata() -> decltype(metainformations_.get()) { + return metainformations_.get(); + } + /** + * Get all tags of the signal + */ + const MetaInformation& metadatas() const { return metainformations_; } + + private: + /** + * Unpack the event to Event and provide it to the caller. + * + * \tparam T Type of the event + * \param callback Caller function which accepts the unpacked event + * \note: In case type T and actual type do not match an exception is thrown + */ + template + void subscribe_impl( + const std::function>&)>& callback) { + typed_value_changed_event_t* value_changed_event = + std::any_cast>(&value_changed_event_); + if (callback) { + assert(value_changed_event); + callback(*value_changed_event); + } + } + + public: + /** + * Subscribe to value-changed events. + * + * \tparam T Type of the signal + * \param callback event-function which will be called when the value changed + */ + template + void subscribe(databroker::on_value_changed_callback_t callback) { + assert_static_type(); + assert_dynamic_type(); + + subscribe_impl( + [callback = std::move(callback)](typed_value_changed_event_t& value_changed_event) { + value_changed_event.add(std::move(callback)); + }); + } + + /** + * Return the count of subscribers to the value_changed event. + * + * \return size_t Count of subscribers to the value_changed event + */ + std::size_t subscriber_count() const { return subscriber_count_(); } + + /** + * Indicate whether the value_changed event has subscribers. + * + * \return bool True if the value_changed event has subscribers, false otherwise + */ + bool has_subscriber() const { return subscriber_count() > 0; } + + /** + * Return the list of names assigned to the signal. + * + * \return List of names assigned to the signal + */ + const std::vector& names() const { return names_; } + + /** + * Return the first assigned name of the signal. + * + * \return First name of the signal + */ + const std::string& name() const { + if (names_.empty()) { + throw std::logic_error("signal does not have a name"); + } + return names_.front(); + } + + /** + * Return the first assigned name of the signal. + * + * \return First name of the signal + */ + std::string name_or(std::string def) const { + if (names_.empty()) { + return def; + } + return names_.front(); + } + + /** + * Add a name of the signal. + * + * \param name Name of the signal + */ + void add_name(std::string_view name) { names_.emplace_back(name); } + + private: + /** + * Factory for Signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return owning unique pointer to Signal + * \note: Design-Note #1 reasons the existance of this factory + */ + template + static std::unique_ptr make() { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + auto signal = std::make_unique(access_token(0)); + signal->initialize(); + return signal; + } + + /** + * Create the container for a signal. + * + * \tparam T Type of the signal + * \return BasicContainer for the signal + * \note Design-#4: Templates shall be instantiated explicitly + */ + template + BasicContainer create_container() { + assert_static_type(); + + typed_value_changed_event_t* value_changed_event = + std::any_cast>(&value_changed_event_); + // Create container + BasicContainer result = BasicContainer( + this, + [value_changed_event](databroker::signal_type_cref_t value) { + value_changed_event->raise(std::move(value)); + }, + typename BasicContainer::access_token(0)); + return result; + } + + public: + private: + template + void initialize() { + assert_static_type(); + + type_ = &typeid(T); + // Create event + value_changed_event_ = typed_value_changed_event_t(); + typed_value_changed_event_t* value_changed_event = + &std::any_cast&>(value_changed_event_); + // Create event-trigger + typed_on_value_change_event_function_t on_value_changed = + [value_changed_event](databroker::signal_type_cref_t value) { + value_changed_event->raise(std::move(value)); + }; + on_value_changed_ = on_value_changed; + // Create subscriber_count function + subscriber_count_ = [value_changed_event]() { return value_changed_event->count(); }; + } + + template + friend class BasicContainer; + friend class DataBroker; +}; + +template +void BasicContainer::update_accessor_functions(BasicContainer* container) { + if (signal_ != nullptr) { + // Create getter-function + if (container) { + signal_->template set_getter( + [container]() -> databroker::signal_type_cref_t { return container->value(); }); + signal_->template set_setter( + [container](databroker::signal_type_cref_t value) { container->set_value(value); }); + } else { + signal_->template set_getter(Signal::typed_get_value_function_t()); + signal_->template set_setter(Signal::typed_set_value_function_t()); + } + } +} + +template +bool BasicContainer::has_subscriber() const { + return signal_ != nullptr && signal_->has_subscriber(); +} + +template +std::size_t BasicContainer::subscriber_count() const { + return signal_ != nullptr ? signal_->subscriber_count() : 0; +} + +/** + * TypedSignal decorates Signal with a specific datatype. + */ +template +class TypedSignal { + private: + SignalPtr signal_; + + public: + TypedSignal(SignalPtr signal) : signal_{signal} {} + ~TypedSignal() = default; + + operator SignalPtr&() { return signal_; } + operator const SignalPtr&() const { return signal_; } + + const T& value() const { return signal_->template value(); } + + template + void set_setter(TSetter setter) { + signal_->template set_setter(std::move(setter)); + } +}; + +/** + * Registry for type-erased signals. + */ +class DataBroker { + public: + using SignalContainer = std::map>; + + private: + SignalContainer signals_{}; + std::unordered_map bindings_{}; + std::unordered_map lua_declared_types_{}; + + public: + DataBroker() = default; + explicit DataBroker(const sol::state_view& lua) : lua_(lua), signals_object_(*lua_) {} + DataBroker(const DataBroker&) = delete; + DataBroker(DataBroker&&) = delete; + ~DataBroker() = default; + DataBroker& operator=(const DataBroker&) = delete; + DataBroker& operator=(DataBroker&&) = delete; + + private: + /** + * Dynamic class which embedds all signals in shape of properties into the Lua-VM + */ + class SignalsObject { + private: + /** + * Lua-Getter Function (C++ -> Lua) + */ + using lua_getter_fn = std::function; + /** + * Lua-Setter Function (Lua -> C++) + */ + using lua_setter_fn = std::function; + /** + * Lua accessors (getter/setter) + */ + struct lua_accessor { + lua_getter_fn getter; + lua_setter_fn setter; + }; + /** + * Signals map (name -> accessors) + */ + using accessors = std::unordered_map; + /** + * Mapped signals + */ + accessors accessors_; + /** + * Lua usertype, declares this class towards Lua + */ + sol::usertype signals_table_; + + public: + SignalsObject(sol::state_view& lua) + : accessors_() + , signals_table_(lua.new_usertype( + "SignalsObject", sol::meta_function::new_index, &SignalsObject::set_property_lua, + sol::meta_function::index, &SignalsObject::get_property_lua)) {} + + /** + * \brief Getter function for dynamic Lua properties + * \param name Accessed name on Lua level + * \param s Current Lua-state + */ + sol::object get_property_lua(const char* name, sol::this_state s) { + auto iter = accessors_.find(name); + if (iter != accessors_.end()) { + auto result = iter->second.getter(s); + return result; + } else { + throw std::out_of_range( + fmt::format("Failure to access signal '{}' from Lua since it is not bound.", name)); + } + } + /** + * \brief Setter function for dynamic Lua properties + * \param name Accessed name on Lua level + * \param object Lua-Object assigned to the property + */ + void set_property_lua(const char* name, sol::stack_object object) { + auto iter = accessors_.find(name); + if (iter != accessors_.end()) { + iter->second.setter(object); + } else { + throw std::out_of_range( + fmt::format("Failure to access signal '{}' from Lua since it is not bound.", name)); + } + } + /** + * Factory which produces the gluecode to r/w Lua properties + */ + template + struct LuaAccessorFactory { + using type = T; + using value_type = T; + static lua_accessor make(const SignalPtr& signal) { + lua_accessor result; + result.getter = [signal](sol::this_state& state) -> sol::object { + const value_type& value = signal->value(); + return sol::make_object(state, value); + }; + result.setter = [signal](sol::stack_object& obj) -> void { + T value = obj.as(); + signal->set_value(value); + }; + return result; + } + }; + /** + * Factory which produces the gluecode to r/w Lua properties + * \note Specialization for std::optional + */ + template + struct LuaAccessorFactory> { + using type = std::optional; + using value_type = T; + static lua_accessor make(const SignalPtr& signal) { + lua_accessor result; + result.getter = [signal](sol::this_state& state) -> sol::object { + const type& value = signal->value(); + if (value) { + return sol::make_object(state, value.value()); + } else { + return sol::make_object(state, sol::lua_nil); + } + }; + result.setter = [signal](sol::stack_object& obj) -> void { + type value; + if (obj != sol::lua_nil) { + value = obj.as(); + } + signal->set_value(value); + }; + return result; + } + }; + + /** + * \brief Binds one signal to Lua + * \param signal signal to be bound to Lua + * \param lua_name name of the signal in Lua + */ + template + void bind(const SignalPtr& signal, std::string_view lua_name) { + lua_accessor accessor = LuaAccessorFactory::make(signal); + auto inserted = accessors_.try_emplace(std::string(lua_name), std::move(accessor)); + if (!inserted.second) { + throw std::out_of_range(fmt::format( + "Failure adding lua-accessor for signal {}. Name already exists.", lua_name)); + } + } + }; + /** + * state_view of Lua + */ + std::optional lua_{}; + /** + * Instance of signals body which incorporates all bound signals + */ + std::optional signals_object_{}; + + public: + /** + * \brief Declares a DataType to Lua (if not yet done) + * \note: The function can be used independent of a bound Lua instance + */ + template + void declare_type(lua_signal_declarator_t type_declarator) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + if (lua_.has_value()) { + std::type_index type{typeid(compatible_type)}; + auto iter = lua_declared_types_.find(type); + if (iter == lua_declared_types_.end()) { + lua_declared_types_[type] = true; + // declare type + type_declarator(*lua_); + } + } + } + + private: + /** + * \brief Declares a DataType to Lua (if not yet done) + * \note: The function can be used independent of a bound Lua instance + */ + template + void declare() { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + if (lua_.has_value()) { + // Check whether this type was already processed, if not declare it and store an adapter function in bindings_ + std::type_index type{typeid(compatible_type)}; + auto iter = bindings_.find(type); + if (iter == bindings_.end()) { + // Check wether this type was already declared to the Lua-VM, if not declare it + auto declared_types_iter = lua_declared_types_.find(type); + if (declared_types_iter == lua_declared_types_.end()) { + lua_declared_types_[type] = true; + ::cloe::databroker::detail::to_lua(*lua_); + } + + // Create adapter for Lua-VM + lua_signal_adapter_t adapter = [this](const SignalPtr& signal, sol::state_view state, + std::string_view lua_name) { + //adapter_impl(signal, state, lua_name); + // Subscribe to the value-changed event to indicate the signal is used + signal->subscribe([](const T&) {}); + // Implement the signal as a property in Lua + signals_object_->bind(signal, lua_name); + }; + // Store adapter function + bindings_.emplace(type, std::move(adapter)); + } + } + } + + public: + /** + * \brief Binds a signal to the Lua-VM + * \param signal_name Name of the signal + * \param lua_name Name of the table/variable used in Lua + * \note The bind-method needs to be invoked at least once (in total) to bring all signal bindings into effect + */ + void bind_signal(std::string_view signal_name, std::string_view lua_name) { + if (!lua_.has_value()) { + throw std::logic_error( + "DataBroker: Binding a signal to Lua must not happen, before binding the Lua " + "context."); + } + + SignalPtr signal = this->signal(signal_name); + auto type = std::type_index(*signal->type()); + + auto iter = bindings_.find(type); + if (iter == bindings_.end()) { + throw std::runtime_error( + "DataBroker: : Lua type binding not implemented"); + } + const lua_signal_adapter_t& adapter = iter->second; + adapter(signal, (*lua_), lua_name); + } + + /** + * \brief Binds a signal to the Lua-VM + * \param signal_name Name of the signal + * \note The bind-method needs to be invoked at least once (in total) to bring all signal bindings into effect + */ + void bind_signal(std::string_view signal_name) { bind_signal(signal_name, signal_name); } + + /** + * \brief Binds the signals-object to Lua + * \param signals_name Name which shall be used for the table + * \param parent_table Parent-table to use + */ + void bind(std::string_view signals_name, sol::table parent) { + parent[signals_name] = &(*signals_object_); + } + + void bind(std::string_view signals_name) { (*lua_)[signals_name] = &(*signals_object_); } + + public: + /** + * Return the signal with the given name. + * + * \param name Name of the signal + * \return Signal with the given name + */ + [[nodiscard]] const SignalContainer::const_iterator operator[](std::string_view name) const { + SignalContainer::const_iterator iter = signals_.find(name); + return iter; + } + + /** + * Return the signal with the given name. + * + * \param name Name of the signal + * \return Signal with the given name + */ + [[nodiscard]] SignalContainer::iterator operator[](std::string_view name) { + SignalContainer::iterator iter = signals_.find(name); + return iter; + } + + /** + * Give an existing signal an alias. + * + * \param signal Signal to be named + * \param new_name New name of the signal + * \return Pointer to the signal + * \note If an exception is thrown by any operation, the aliasing has no effect. + */ + SignalPtr alias(SignalPtr signal, std::string_view new_name) { + if (new_name.empty()) { + throw std::invalid_argument( + fmt::format("alias for signal must not be empty: {}", signal->name_or(""))); + } + + // Mutate signals + std::pair inserted = + signals_.try_emplace(std::string(new_name), std::move(signal)); + if (!inserted.second) { + throw std::out_of_range(fmt::format("cannot alias signal '{}' to '{}': name already exists", + signal->name_or(""), new_name)); + } + // signals mutated, there is a liability in case of exceptions + try { + SignalPtr result = inserted.first->second; + result->add_name(new_name); + return result; + } catch (...) { + // fullfill exception guarantee (map.erase(iter) does not throw) + signals_.erase(inserted.first); + throw; + } + } + + /** + * Give an existing signal a (new) name. + * + * \param old_name Name of the existing signal + * \param new_name New name of the signal + * \param f flag_type flags used to guide the interpretation of the character sequence as a regular expression + * \return Pointer to the signal + * \note If an exception is thrown by any operation, the aliasing has no effect. + */ + SignalPtr alias(std::string_view old_name, std::string_view new_name, + std::regex::flag_type f = std::regex_constants::ECMAScript) { + std::regex regex = std::regex(std::string(old_name), f); + auto it1 = signals().begin(); + auto it2 = signals().begin(); + auto end = signals().end(); + const auto predicate = [&](const auto& item) -> bool { + std::smatch match; + return std::regex_match(item.first, match, regex); + }; + it1 = (it1 != end) ? std::find_if(it1, end, predicate) : end; + it2 = (it1 != end) ? std::find_if(std::next(it1), end, predicate) : end; + if (it2 != end) { + throw std::out_of_range( + fmt::format("regex pattern matches multiple signals: '{}'; matches: '{}', '{}'", old_name, + it1->first, it2->first)); + } + if (it1 == end) { + throw std::out_of_range(fmt::format("regex pattern matches zero signals: {}", old_name)); + return nullptr; + } + SignalPtr result = alias(it1->second, new_name); + return result; + } + + /** + * Declare a new signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return Pointer to the specified signal + */ + template + SignalPtr declare(std::string_view new_name) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + declare(); + + SignalPtr signal = Signal::make(); + alias(signal, new_name); + return signal; + } + + /** + * Declare a new signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return Container storing the signal value + */ + template + [[nodiscard]] Container> implement(std::string_view new_name) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + declare(); + + SignalPtr signal = declare(new_name); + Container container = signal->create_container(); + return container; + } + /** + * Return the signal with the given name. + * + * \param name Name of the signal + * \return Signal with the given name + */ + [[nodiscard]] SignalPtr signal(std::string_view name) const { + auto iter = (*this)[name]; + if (iter != signals_.end()) { + return iter->second; + } + throw std::out_of_range(fmt::format("signal not found: {}", name)); + } + + /** + * Return the signal with the given name. + * + * \param name Name of the signal + * \return Signal with the given name + */ + [[nodiscard]] SignalPtr signal(std::string_view name) { + auto iter = (*this)[name]; + if (iter != signals_.end()) { + return iter->second; + } + throw std::out_of_range(fmt::format("signal not found: {}", name)); + } + + /** + * Return all signals. + */ + [[nodiscard]] const SignalContainer& signals() const { return signals_; } + + /** + * Return all signals. + */ + [[nodiscard]] SignalContainer& signals() { + // DRY + return const_cast(const_cast(this)->signals()); + } + + /** + * Subscribe to value-changed events. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \param callback event-function which will be called when the value changed + */ + template + void subscribe(std::string_view name, databroker::on_value_changed_callback_t callback) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + signal(name)->template subscribe(std::move(callback)); + } + + /** + * Set the value of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \param value Value to be assigned to the signal + */ + template + void set_value(std::string_view name, databroker::signal_type_cref_t value) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + signal(name)->set_value(value); + } + + /** + * Return the value of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return Pointer to the value of the signal + * \note databroker::compatible_base_t == T, if the function compiles + */ + template + databroker::signal_type_cref_t value(std::string_view name) const { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + return signal(name)->value(); + } + + /** + * Return the getter-function of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return getter-function of the signal + */ + template + const Signal::typed_get_value_function_t& getter(std::string_view name) const { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + const Signal::typed_get_value_function_t* getter_fn = + signal(name)->getter(); + if (!getter_fn) { + throw std::logic_error(fmt::format("getter for signal not provided: {}", name)); + } + return *getter_fn; + } + /** + * Sets the getter-function of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \param getter_fn getter-function of the signal + */ + template + void set_getter(std::string_view name, const Signal::typed_get_value_function_t& getter_fn) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + signal(name)->set_getter(getter_fn); + } + + /** + * Return the setter-function of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \return const Signal::typed_set_value_function_t&, setter-function of the signal + */ + template + const Signal::typed_set_value_function_t& setter(std::string_view name) const { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + const Signal::typed_set_value_function_t* setter_fn = + signal(name)->setter(); + if (!setter_fn) { + throw std::logic_error(fmt::format("setter for signal not provided: {}", name)); + } + return *setter_fn; + } + /** + * Sets the setter-function of a signal. + * + * \tparam T Type of the signal + * \param name Name of the signal + * \param getter_fn setter-function of the signal + */ + template + void set_setter(std::string_view name, const Signal::typed_set_value_function_t& setter_fn) { + assert_static_type(); + using compatible_type = databroker::compatible_base_t; + + signal(name)->set_setter(setter_fn); + } +}; + +namespace databroker { + +struct DynamicName { + static constexpr bool STATIC = false; + + private: + std::string name_; + + public: + DynamicName(std::string name) : name_{name} {} + const std::string& name() const { return name_; } +}; + +template +struct StaticName { + static constexpr bool STATIC = true; + static constexpr const char* name() { return NAME; } +}; + +/** + * SignalDescriptorBase implements a SignalDescriptor + * + * \tparam T Type of the signal + * \tparam TNAME Name of the signal + * \tparam bool true for names, false otherwise + */ +template +struct SignalDescriptorBase {}; + +/** + * SignalDescriptorBase implements a SignalDescriptor, specialization for statically determined signal names + * + * \tparam T Type of the signal + * \tparam TNAME Name of the signal + */ +template +struct SignalDescriptorBase : public TNAME { + using TNAME::name; + using TNAME::TNAME; + /** + * Implements the signal + * + * \param db Instance of the DataBroker + * \return Container, the container of the signal + */ + static auto implement(DataBroker& db) { return db.implement(name()); } + /** + * Declares the signal + * + * \param db Instance of the DataBroker + * \return TypeSignal, the signal + */ + static void declare(DataBroker& db) { db.declare(name()); } + /** + * Returns the instance of a signal. + * + * \param db Instance of the DataBroker + * \return TypedSignal, instance of the signal + */ + static auto signal(const DataBroker& db) { return TypedSignal(db.signal(name())); } + /** + * Return the getter-function of a signal. + * + * \param db Instance of the DataBroker + * \return const Signal::typed_get_value_function_t&, getter-function of the signal + */ + static auto getter(const DataBroker& db) { return db.getter(name()); } + /** + * Sets the getter-function of a signal. + * + * \param db Instance of the DataBroker + * \param get_value_fn getter-function of the signal + */ + static void set_getter(DataBroker& db, Signal::typed_get_value_function_t get_value_fn) { + db.set_getter(name(), std::move(get_value_fn)); + } + /** + * Return the setter-function of a signal. + * + * \param db Instance of the DataBroker + * \return const Signal::typed_set_value_function_t&, setter-function of the signal + */ + static auto setter(const DataBroker& db) { return db.setter(name()); } + /** + * Sets the setter-function of a signal. + * + * \param db Instance of the DataBroker + * \param set_value_fn setter-function of the signal + */ + static void set_setter(DataBroker& db, Signal::typed_set_value_function_t set_value_fn) { + db.set_setter(name(), std::move(set_value_fn)); + } + + /** + * Return the value of a signal. + * + * \param db Instance of the DataBroker + * \return Pointer to the value of the signal, nullptr if the signal does not exist + */ + static auto value(DataBroker& db) { return db.value(name()); } + /** + * Set the value of a signal. + * + * \param db Instance of the DataBroker + * \param value Value to be assigned to the signal + */ + static auto set_value(DataBroker& db, const T& value) { db.set_value(name(), value); } +}; + +/** + * SignalDescriptorBase implements a SignalDescriptor, specialization for dynamically determined signal names + * + * \tparam T Type of the signal + */ +template +struct SignalDescriptorBase : public TNAME { + using TNAME::name; + using TNAME::TNAME; + /** + * Implements the signal + * + * \param db Instance of the DataBroker + * \return Container, the container of the signal + */ + auto implement(DataBroker& db) { return db.implement(name()); } + /** + * Declares the signal + * + * \param db Instance of the DataBroker + * \return TypeSignal, the signal + */ + void declare(DataBroker& db) { db.declare(name()); } + /** + * Returns the instance of a signal. + * + * \param db Instance of the DataBroker + * \return TypedSignal, instance of the signal + */ + auto signal(const DataBroker& db) const { return TypedSignal(db.signal(name())); } + /** + * Return the getter-function of a signal. + * + * \param db Instance of the DataBroker + * \return const Signal::typed_get_value_function_t&, getter-function of the signal + */ + auto getter(const DataBroker& db) const { return db.getter(name()); } + /** + * Sets the getter-function of a signal. + * + * \param db Instance of the DataBroker + * \param get_value_fn getter-function of the signal + */ + void set_getter(DataBroker& db, Signal::typed_get_value_function_t get_value_fn) { + db.set_getter(name(), std::move(get_value_fn)); + } + /** + * Return the setter-function of a signal. + * + * \param db Instance of the DataBroker + * \return const Signal::typed_set_value_function_t&, setter-function of the signal + */ + auto setter(const DataBroker& db) const { return db.setter(name()); } + /** + * Sets the setter-function of a signal. + * + * \param db Instance of the DataBroker + * \param set_value_fn setter-function of the signal + */ + void set_setter(DataBroker& db, Signal::typed_set_value_function_t set_value_fn) { + db.set_setter(name(), std::move(set_value_fn)); + } + + /** + * Return the value of a signal. + * + * \param db Instance of the DataBroker + * \return Pointer to the value of the signal, nullptr if the signal does not exist + */ + auto value(const DataBroker& db) const { return db.value(name()); } + /** + * Set the value of a signal. + * + * \param db Instance of the DataBroker + * \param value Value to be assigned to the signal + */ + auto set_value(DataBroker& db, const T& value) { db.set_value(name(), value); } +}; + +/** + * SignalDescriptorImpl implements a SignalDescriptor for names + * + * \tparam T Type of the signal + */ +template +struct SignalDescriptorImpl : public SignalDescriptorBase> { + using SignalDescriptorBase>::SignalDescriptorBase; +}; +/** + * SignalDescriptorImpl implements a SignalDescriptor for dynamic names + * + * \tparam T Type of the signal + */ +template +struct SignalDescriptorImpl : public SignalDescriptorBase { + using SignalDescriptorBase::SignalDescriptorBase; +}; + +/** + * SignalDescriptor reflects properties of a signal at compile-/run-time + * + * \tparam T Type of the signal + * \tparam NAME compile-time name, nullptr otherwise + * \note: Design-Goals: + * - Design-#1: Datatype of a signal shall be available at compile time + * \note: Remarks: + * - The declaration of a descriptor does not imply the availability of the coresponding signal at runtime. + * Likewise a C/C++ header does not imply that the coresponding symbols can be resolved at runtime. + */ +template +struct SignalDescriptor : public SignalDescriptorImpl { + using SignalDescriptorImpl::SignalDescriptorImpl; +}; + +template +struct SignalTemplate : public StaticName { + private: + template