diff --git a/source/core/StarLua.cpp b/source/core/StarLua.cpp index 287bf5c..901e875 100644 --- a/source/core/StarLua.cpp +++ b/source/core/StarLua.cpp @@ -115,6 +115,11 @@ LuaTable LuaContext::createTable() { return engine().createTable(); } +LuaNullEnforcer::LuaNullEnforcer(LuaEngine& engine) + : m_engine(&engine) { ++m_engine->m_nullTerminated; }; + +LuaNullEnforcer::~LuaNullEnforcer() { --m_engine->m_nullTerminated; }; + LuaValue LuaConverter::from(LuaEngine& engine, Json const& v) { if (v.isType(Json::Type::Null)) { return LuaNil; @@ -202,6 +207,7 @@ LuaEnginePtr LuaEngine::create(bool safe) { self->m_instructionCount = 0; self->m_recursionLevel = 0; self->m_recursionLimit = 0; + self->m_nullTerminated = 0; if (!self->m_state) throw LuaException("Failed to initialize Lua"); @@ -455,7 +461,10 @@ ByteArray LuaEngine::compile(ByteArray const& contents, String const& name) { LuaString LuaEngine::createString(String const& str) { lua_checkstack(m_state, 1); - lua_pushlstring(m_state, str.utf8Ptr(), str.utf8Size()); + if (m_nullTerminated) + lua_pushstring(m_state, str.utf8Ptr()); + else + lua_pushlstring(m_state, str.utf8Ptr(), str.utf8Size()); return LuaString(LuaDetail::LuaHandle(RefPtr(this), popHandle(m_state))); } @@ -543,6 +552,10 @@ size_t LuaEngine::memoryUsage() const { return (size_t)lua_gc(m_state, LUA_GCCOUNT, 0) * 1024 + lua_gc(m_state, LUA_GCCOUNTB, 0); } +LuaNullEnforcer LuaEngine::nullTerminate() { + return LuaNullEnforcer(*this); +} + LuaEngine* LuaEngine::luaEnginePtr(lua_State* state) { return (*reinterpret_cast(lua_getextraspace(state))); } @@ -710,9 +723,33 @@ char const* LuaEngine::stringPtr(int handleIndex) { } size_t LuaEngine::stringLength(int handleIndex) { - size_t len = 0; - lua_tolstring(m_handleThread, handleIndex, &len); - return len; + if (m_nullTerminated) + return strlen(lua_tostring(m_handleThread, handleIndex)); + else { + size_t len = 0; + lua_tolstring(m_handleThread, handleIndex, &len); + return len; + } +} + +String LuaEngine::string(int handleIndex) { + if (m_nullTerminated) + return String(lua_tostring(m_handleThread, handleIndex)); + else { + size_t len = 0; + const char* data = lua_tolstring(m_handleThread, handleIndex, &len); + return String(data, len); + } +} + +StringView LuaEngine::stringView(int handleIndex) { + if (m_nullTerminated) + return StringView(lua_tostring(m_handleThread, handleIndex)); + else { + size_t len = 0; + const char* data = lua_tolstring(m_handleThread, handleIndex, &len); + return StringView(data, len); + } } LuaValue LuaEngine::tableGet(bool raw, int handleIndex, LuaValue const& key) { diff --git a/source/core/StarLua.hpp b/source/core/StarLua.hpp index dda0802..0deb76f 100644 --- a/source/core/StarLua.hpp +++ b/source/core/StarLua.hpp @@ -156,6 +156,7 @@ public: size_t length() const; String toString() const; + StringView view() const; }; bool operator==(LuaString const& s1, LuaString const& s2); @@ -383,6 +384,38 @@ public: LuaUserData createUserData(T t); }; +template +struct LuaNullTermWrapper : T { + LuaNullTermWrapper() : T() {} + LuaNullTermWrapper(LuaNullTermWrapper const& nt) : T(nt) {} + LuaNullTermWrapper(LuaNullTermWrapper&& nt) : T(std::move(nt)) {} + LuaNullTermWrapper(T const& bt) : T(bt) {} + LuaNullTermWrapper(T&& bt) : T(std::move(bt)) {} + + LuaNullTermWrapper& operator=(LuaNullTermWrapper const& rhs) { + T::operator=(rhs); + return *this; + } + + LuaNullTermWrapper& operator=(LuaNullTermWrapper&& rhs) { + T::operator=(std::move(rhs)); + return *this; + } + + LuaNullTermWrapper& operator=(T&& other) { + T::operator=(std::forward(other)); + return *this; + } +}; + +class LuaNullEnforcer { +public: + LuaNullEnforcer(LuaEngine& engine); + ~LuaNullEnforcer(); +private: + LuaEngine* m_engine; +}; + // Types that want to participate in automatic lua conversion should specialize // this template and provide static to and from methods on it. The method // signatures will be called like: @@ -565,6 +598,9 @@ public: // Bytes in use by lua size_t memoryUsage() const; + // Enforce null-terminated string conversion as long as the returned enforcer object is in scope. + LuaNullEnforcer nullTerminate(); + private: friend struct LuaDetail::LuaHandle; friend class LuaReference; @@ -574,6 +610,7 @@ private: friend class LuaThread; friend class LuaUserData; friend class LuaContext; + friend class LuaNullEnforcer; LuaEngine() = default; @@ -599,6 +636,8 @@ private: char const* stringPtr(int handleIndex); size_t stringLength(int handleIndex); + String string(int handleIndex); + StringView stringView(int handleIndex); LuaValue tableGet(bool raw, int handleIndex, LuaValue const& key); LuaValue tableGet(bool raw, int handleIndex, char const* key); @@ -690,6 +729,7 @@ private: uint64_t m_instructionCount; unsigned m_recursionLevel; unsigned m_recursionLimit; + unsigned m_nullTerminated; HashMap, shared_ptr> m_profileEntries; }; @@ -796,7 +836,7 @@ struct LuaConverter { static Maybe to(LuaEngine&, LuaValue const& v) { if (v.is()) - return String(v.get().ptr()); + return v.get().toString(); if (v.is()) return String(toString(v.get())); if (v.is()) @@ -1579,35 +1619,39 @@ inline size_t LuaString::length() const { } inline String LuaString::toString() const { - return String(ptr()); + return engine().string(handleIndex()); +} + +inline StringView LuaString::view() const { + return engine().stringView(handleIndex()); } inline bool operator==(LuaString const& s1, LuaString const& s2) { - return std::strcmp(s1.ptr(), s2.ptr()) == 0; + return s1.view() == s2.view(); } inline bool operator==(LuaString const& s1, char const* s2) { - return std::strcmp(s1.ptr(), s2) == 0; + return s1.view() == s2; } inline bool operator==(LuaString const& s1, std::string const& s2) { - return s1.ptr() == s2; + return s1.view() == s2; } inline bool operator==(LuaString const& s1, String const& s2) { - return s1.ptr() == s2; + return s1.view() == s2; } inline bool operator==(char const* s1, LuaString const& s2) { - return std::strcmp(s1, s2.ptr()) == 0; + return s2.view() == s1; } inline bool operator==(std::string const& s1, LuaString const& s2) { - return s1 == s2.ptr(); + return s2.view() == s1; } inline bool operator==(String const& s1, LuaString const& s2) { - return s1 == s2.ptr(); + return s2.view() == s1; } inline bool operator!=(LuaString const& s1, LuaString const& s2) { diff --git a/source/core/StarLuaConverters.hpp b/source/core/StarLuaConverters.hpp index eb06e35..72ab763 100644 --- a/source/core/StarLuaConverters.hpp +++ b/source/core/StarLuaConverters.hpp @@ -11,6 +11,23 @@ namespace Star { +template +struct LuaConverter> : LuaConverter { + static LuaValue from(LuaEngine& engine, LuaNullTermWrapper&& v) { + auto enforcer = engine.nullTerminate(); + return LuaConverter::from(std::forward(v)); + } + + static LuaValue from(LuaEngine& engine, LuaNullTermWrapper const& v) { + auto enforcer = engine.nullTerminate(); + return LuaConverter::from(v); + } + + static LuaNullTermWrapper to(LuaEngine& engine, LuaValue const& v) { + return LuaConverter::to(v); + } +}; + template struct LuaConverter> { static LuaValue from(LuaEngine& engine, pair&& v) { diff --git a/source/core/StarStringView.cpp b/source/core/StarStringView.cpp index 8ad3b8c..4b45fe6 100644 --- a/source/core/StarStringView.cpp +++ b/source/core/StarStringView.cpp @@ -418,6 +418,14 @@ bool operator==(StringView s1, const char* s2) { return s1.m_view.compare(s2) == 0; } +bool operator==(StringView s1, std::string const& s2) { + return s1.m_view.compare(s2) == 0; +} + +bool operator==(StringView s1, String const& s2) { + return s1.m_view.compare(s2.utf8()) == 0; +} + bool operator==(StringView s1, StringView s2) { return s1.m_view == s2.m_view; } diff --git a/source/core/StarStringView.hpp b/source/core/StarStringView.hpp index 3ebc952..5598d0f 100644 --- a/source/core/StarStringView.hpp +++ b/source/core/StarStringView.hpp @@ -96,6 +96,8 @@ public: StringView& operator=(StringView s); friend bool operator==(StringView s1, const char* s2); + friend bool operator==(StringView s1, std::string const& s2); + friend bool operator==(StringView s1, String const& s2); friend bool operator==(StringView s1, StringView s2); friend bool operator!=(StringView s1, StringView s2); friend bool operator<(StringView s1, StringView s2); diff --git a/source/game/scripting/StarWorldLuaBindings.cpp b/source/game/scripting/StarWorldLuaBindings.cpp index 6d1648e..087f88c 100644 --- a/source/game/scripting/StarWorldLuaBindings.cpp +++ b/source/game/scripting/StarWorldLuaBindings.cpp @@ -1421,7 +1421,7 @@ namespace LuaBindings { return {}; } - Maybe WorldEntityCallbacks::entityPortrait(World* world, EntityId entityId, String const& portraitMode) { + LuaNullTermWrapper> WorldEntityCallbacks::entityPortrait(World* world, EntityId entityId, String const& portraitMode) { auto entity = world->entity(entityId); if (auto portraitEntity = as(entity)) { @@ -1471,7 +1471,7 @@ namespace LuaBindings { return Json(); } - Maybe WorldEntityCallbacks::entityUniqueId(World* world, EntityId entityId) { + LuaNullTermWrapper> WorldEntityCallbacks::entityUniqueId(World* world, EntityId entityId) { if (auto entity = world->entity(entityId)) return entity->uniqueId(); return {}; diff --git a/source/game/scripting/StarWorldLuaBindings.hpp b/source/game/scripting/StarWorldLuaBindings.hpp index 725e824..e2d082c 100644 --- a/source/game/scripting/StarWorldLuaBindings.hpp +++ b/source/game/scripting/StarWorldLuaBindings.hpp @@ -122,10 +122,10 @@ namespace LuaBindings { Maybe entityGender(World* world, EntityId entityId); Maybe entityName(World* world, EntityId entityId); Maybe entityDescription(World* world, EntityId entityId, Maybe const& species); - Maybe entityPortrait(World* world, EntityId entityId, String const& portraitMode); + LuaNullTermWrapper> entityPortrait(World* world, EntityId entityId, String const& portraitMode); Maybe entityHandItem(World* world, EntityId entityId, String const& handName); Json entityHandItemDescriptor(World* world, EntityId entityId, String const& handName); - Maybe entityUniqueId(World* world, EntityId entityId); + LuaNullTermWrapper> entityUniqueId(World* world, EntityId entityId); Json getObjectParameter(World* world, EntityId entityId, String const& parameterName, Maybe const& defaultValue); Json getNpcScriptParameter(World* world, EntityId entityId, String const& parameterName, Maybe const& defaultValue); List objectSpaces(World* world, EntityId entityId);