Proper LuaString <-> String handling

Null-terminator usage is opt-in
This commit is contained in:
Kae 2023-07-24 23:37:55 +10:00
parent 4c636e911c
commit 8547c56ba4
7 changed files with 125 additions and 17 deletions

View File

@ -115,6 +115,11 @@ LuaTable LuaContext::createTable() {
return engine().createTable(); return engine().createTable();
} }
LuaNullEnforcer::LuaNullEnforcer(LuaEngine& engine)
: m_engine(&engine) { ++m_engine->m_nullTerminated; };
LuaNullEnforcer::~LuaNullEnforcer() { --m_engine->m_nullTerminated; };
LuaValue LuaConverter<Json>::from(LuaEngine& engine, Json const& v) { LuaValue LuaConverter<Json>::from(LuaEngine& engine, Json const& v) {
if (v.isType(Json::Type::Null)) { if (v.isType(Json::Type::Null)) {
return LuaNil; return LuaNil;
@ -202,6 +207,7 @@ LuaEnginePtr LuaEngine::create(bool safe) {
self->m_instructionCount = 0; self->m_instructionCount = 0;
self->m_recursionLevel = 0; self->m_recursionLevel = 0;
self->m_recursionLimit = 0; self->m_recursionLimit = 0;
self->m_nullTerminated = 0;
if (!self->m_state) if (!self->m_state)
throw LuaException("Failed to initialize Lua"); throw LuaException("Failed to initialize Lua");
@ -455,6 +461,9 @@ ByteArray LuaEngine::compile(ByteArray const& contents, String const& name) {
LuaString LuaEngine::createString(String const& str) { LuaString LuaEngine::createString(String const& str) {
lua_checkstack(m_state, 1); lua_checkstack(m_state, 1);
if (m_nullTerminated)
lua_pushstring(m_state, str.utf8Ptr());
else
lua_pushlstring(m_state, str.utf8Ptr(), str.utf8Size()); lua_pushlstring(m_state, str.utf8Ptr(), str.utf8Size());
return LuaString(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state))); return LuaString(LuaDetail::LuaHandle(RefPtr<LuaEngine>(this), popHandle(m_state)));
} }
@ -543,6 +552,10 @@ size_t LuaEngine::memoryUsage() const {
return (size_t)lua_gc(m_state, LUA_GCCOUNT, 0) * 1024 + lua_gc(m_state, LUA_GCCOUNTB, 0); return (size_t)lua_gc(m_state, LUA_GCCOUNT, 0) * 1024 + lua_gc(m_state, LUA_GCCOUNTB, 0);
} }
LuaNullEnforcer LuaEngine::nullTerminate() {
return LuaNullEnforcer(*this);
}
LuaEngine* LuaEngine::luaEnginePtr(lua_State* state) { LuaEngine* LuaEngine::luaEnginePtr(lua_State* state) {
return (*reinterpret_cast<LuaEngine**>(lua_getextraspace(state))); return (*reinterpret_cast<LuaEngine**>(lua_getextraspace(state)));
} }
@ -710,9 +723,33 @@ char const* LuaEngine::stringPtr(int handleIndex) {
} }
size_t LuaEngine::stringLength(int handleIndex) { size_t LuaEngine::stringLength(int handleIndex) {
if (m_nullTerminated)
return strlen(lua_tostring(m_handleThread, handleIndex));
else {
size_t len = 0; size_t len = 0;
lua_tolstring(m_handleThread, handleIndex, &len); lua_tolstring(m_handleThread, handleIndex, &len);
return len; return len;
}
}
String LuaEngine::string(int handleIndex) {
if (m_nullTerminated)
return String(lua_tostring(m_handleThread, handleIndex));
else {
size_t len = 0;
const char* data = lua_tolstring(m_handleThread, handleIndex, &len);
return String(data, len);
}
}
StringView LuaEngine::stringView(int handleIndex) {
if (m_nullTerminated)
return StringView(lua_tostring(m_handleThread, handleIndex));
else {
size_t len = 0;
const char* data = lua_tolstring(m_handleThread, handleIndex, &len);
return StringView(data, len);
}
} }
LuaValue LuaEngine::tableGet(bool raw, int handleIndex, LuaValue const& key) { LuaValue LuaEngine::tableGet(bool raw, int handleIndex, LuaValue const& key) {

View File

@ -156,6 +156,7 @@ public:
size_t length() const; size_t length() const;
String toString() const; String toString() const;
StringView view() const;
}; };
bool operator==(LuaString const& s1, LuaString const& s2); bool operator==(LuaString const& s1, LuaString const& s2);
@ -383,6 +384,38 @@ public:
LuaUserData createUserData(T t); LuaUserData createUserData(T t);
}; };
template <typename T>
struct LuaNullTermWrapper : T {
LuaNullTermWrapper() : T() {}
LuaNullTermWrapper(LuaNullTermWrapper const& nt) : T(nt) {}
LuaNullTermWrapper(LuaNullTermWrapper&& nt) : T(std::move(nt)) {}
LuaNullTermWrapper(T const& bt) : T(bt) {}
LuaNullTermWrapper(T&& bt) : T(std::move(bt)) {}
LuaNullTermWrapper& operator=(LuaNullTermWrapper const& rhs) {
T::operator=(rhs);
return *this;
}
LuaNullTermWrapper& operator=(LuaNullTermWrapper&& rhs) {
T::operator=(std::move(rhs));
return *this;
}
LuaNullTermWrapper& operator=(T&& other) {
T::operator=(std::forward<T>(other));
return *this;
}
};
class LuaNullEnforcer {
public:
LuaNullEnforcer(LuaEngine& engine);
~LuaNullEnforcer();
private:
LuaEngine* m_engine;
};
// Types that want to participate in automatic lua conversion should specialize // Types that want to participate in automatic lua conversion should specialize
// this template and provide static to and from methods on it. The method // this template and provide static to and from methods on it. The method
// signatures will be called like: // signatures will be called like:
@ -565,6 +598,9 @@ public:
// Bytes in use by lua // Bytes in use by lua
size_t memoryUsage() const; size_t memoryUsage() const;
// Enforce null-terminated string conversion as long as the returned enforcer object is in scope.
LuaNullEnforcer nullTerminate();
private: private:
friend struct LuaDetail::LuaHandle; friend struct LuaDetail::LuaHandle;
friend class LuaReference; friend class LuaReference;
@ -574,6 +610,7 @@ private:
friend class LuaThread; friend class LuaThread;
friend class LuaUserData; friend class LuaUserData;
friend class LuaContext; friend class LuaContext;
friend class LuaNullEnforcer;
LuaEngine() = default; LuaEngine() = default;
@ -599,6 +636,8 @@ private:
char const* stringPtr(int handleIndex); char const* stringPtr(int handleIndex);
size_t stringLength(int handleIndex); size_t stringLength(int handleIndex);
String string(int handleIndex);
StringView stringView(int handleIndex);
LuaValue tableGet(bool raw, int handleIndex, LuaValue const& key); LuaValue tableGet(bool raw, int handleIndex, LuaValue const& key);
LuaValue tableGet(bool raw, int handleIndex, char const* key); LuaValue tableGet(bool raw, int handleIndex, char const* key);
@ -690,6 +729,7 @@ private:
uint64_t m_instructionCount; uint64_t m_instructionCount;
unsigned m_recursionLevel; unsigned m_recursionLevel;
unsigned m_recursionLimit; unsigned m_recursionLimit;
unsigned m_nullTerminated;
HashMap<tuple<String, unsigned>, shared_ptr<LuaProfileEntry>> m_profileEntries; HashMap<tuple<String, unsigned>, shared_ptr<LuaProfileEntry>> m_profileEntries;
}; };
@ -796,7 +836,7 @@ struct LuaConverter<String> {
static Maybe<String> to(LuaEngine&, LuaValue const& v) { static Maybe<String> to(LuaEngine&, LuaValue const& v) {
if (v.is<LuaString>()) if (v.is<LuaString>())
return String(v.get<LuaString>().ptr()); return v.get<LuaString>().toString();
if (v.is<LuaInt>()) if (v.is<LuaInt>())
return String(toString(v.get<LuaInt>())); return String(toString(v.get<LuaInt>()));
if (v.is<LuaFloat>()) if (v.is<LuaFloat>())
@ -1579,35 +1619,39 @@ inline size_t LuaString::length() const {
} }
inline String LuaString::toString() const { inline String LuaString::toString() const {
return String(ptr()); return engine().string(handleIndex());
}
inline StringView LuaString::view() const {
return engine().stringView(handleIndex());
} }
inline bool operator==(LuaString const& s1, LuaString const& s2) { inline bool operator==(LuaString const& s1, LuaString const& s2) {
return std::strcmp(s1.ptr(), s2.ptr()) == 0; return s1.view() == s2.view();
} }
inline bool operator==(LuaString const& s1, char const* s2) { inline bool operator==(LuaString const& s1, char const* s2) {
return std::strcmp(s1.ptr(), s2) == 0; return s1.view() == s2;
} }
inline bool operator==(LuaString const& s1, std::string const& s2) { inline bool operator==(LuaString const& s1, std::string const& s2) {
return s1.ptr() == s2; return s1.view() == s2;
} }
inline bool operator==(LuaString const& s1, String const& s2) { inline bool operator==(LuaString const& s1, String const& s2) {
return s1.ptr() == s2; return s1.view() == s2;
} }
inline bool operator==(char const* s1, LuaString const& s2) { inline bool operator==(char const* s1, LuaString const& s2) {
return std::strcmp(s1, s2.ptr()) == 0; return s2.view() == s1;
} }
inline bool operator==(std::string const& s1, LuaString const& s2) { inline bool operator==(std::string const& s1, LuaString const& s2) {
return s1 == s2.ptr(); return s2.view() == s1;
} }
inline bool operator==(String const& s1, LuaString const& s2) { inline bool operator==(String const& s1, LuaString const& s2) {
return s1 == s2.ptr(); return s2.view() == s1;
} }
inline bool operator!=(LuaString const& s1, LuaString const& s2) { inline bool operator!=(LuaString const& s1, LuaString const& s2) {

View File

@ -11,6 +11,23 @@
namespace Star { namespace Star {
template <typename T>
struct LuaConverter<LuaNullTermWrapper<T>> : LuaConverter<T> {
static LuaValue from(LuaEngine& engine, LuaNullTermWrapper<T>&& v) {
auto enforcer = engine.nullTerminate();
return LuaConverter<T>::from(std::forward<T>(v));
}
static LuaValue from(LuaEngine& engine, LuaNullTermWrapper<T> const& v) {
auto enforcer = engine.nullTerminate();
return LuaConverter<T>::from(v);
}
static LuaNullTermWrapper<T> to(LuaEngine& engine, LuaValue const& v) {
return LuaConverter<T>::to(v);
}
};
template <typename T1, typename T2> template <typename T1, typename T2>
struct LuaConverter<pair<T1, T2>> { struct LuaConverter<pair<T1, T2>> {
static LuaValue from(LuaEngine& engine, pair<T1, T2>&& v) { static LuaValue from(LuaEngine& engine, pair<T1, T2>&& v) {

View File

@ -418,6 +418,14 @@ bool operator==(StringView s1, const char* s2) {
return s1.m_view.compare(s2) == 0; return s1.m_view.compare(s2) == 0;
} }
bool operator==(StringView s1, std::string const& s2) {
return s1.m_view.compare(s2) == 0;
}
bool operator==(StringView s1, String const& s2) {
return s1.m_view.compare(s2.utf8()) == 0;
}
bool operator==(StringView s1, StringView s2) { bool operator==(StringView s1, StringView s2) {
return s1.m_view == s2.m_view; return s1.m_view == s2.m_view;
} }

View File

@ -96,6 +96,8 @@ public:
StringView& operator=(StringView s); StringView& operator=(StringView s);
friend bool operator==(StringView s1, const char* s2); friend bool operator==(StringView s1, const char* s2);
friend bool operator==(StringView s1, std::string const& s2);
friend bool operator==(StringView s1, String const& s2);
friend bool operator==(StringView s1, StringView s2); friend bool operator==(StringView s1, StringView s2);
friend bool operator!=(StringView s1, StringView s2); friend bool operator!=(StringView s1, StringView s2);
friend bool operator<(StringView s1, StringView s2); friend bool operator<(StringView s1, StringView s2);

View File

@ -1421,7 +1421,7 @@ namespace LuaBindings {
return {}; return {};
} }
Maybe<JsonArray> WorldEntityCallbacks::entityPortrait(World* world, EntityId entityId, String const& portraitMode) { LuaNullTermWrapper<Maybe<JsonArray>> WorldEntityCallbacks::entityPortrait(World* world, EntityId entityId, String const& portraitMode) {
auto entity = world->entity(entityId); auto entity = world->entity(entityId);
if (auto portraitEntity = as<PortraitEntity>(entity)) { if (auto portraitEntity = as<PortraitEntity>(entity)) {
@ -1471,7 +1471,7 @@ namespace LuaBindings {
return Json(); return Json();
} }
Maybe<String> WorldEntityCallbacks::entityUniqueId(World* world, EntityId entityId) { LuaNullTermWrapper<Maybe<String>> WorldEntityCallbacks::entityUniqueId(World* world, EntityId entityId) {
if (auto entity = world->entity(entityId)) if (auto entity = world->entity(entityId))
return entity->uniqueId(); return entity->uniqueId();
return {}; return {};

View File

@ -122,10 +122,10 @@ namespace LuaBindings {
Maybe<String> entityGender(World* world, EntityId entityId); Maybe<String> entityGender(World* world, EntityId entityId);
Maybe<String> entityName(World* world, EntityId entityId); Maybe<String> entityName(World* world, EntityId entityId);
Maybe<String> entityDescription(World* world, EntityId entityId, Maybe<String> const& species); Maybe<String> entityDescription(World* world, EntityId entityId, Maybe<String> const& species);
Maybe<JsonArray> entityPortrait(World* world, EntityId entityId, String const& portraitMode); LuaNullTermWrapper<Maybe<JsonArray>> entityPortrait(World* world, EntityId entityId, String const& portraitMode);
Maybe<String> entityHandItem(World* world, EntityId entityId, String const& handName); Maybe<String> entityHandItem(World* world, EntityId entityId, String const& handName);
Json entityHandItemDescriptor(World* world, EntityId entityId, String const& handName); Json entityHandItemDescriptor(World* world, EntityId entityId, String const& handName);
Maybe<String> entityUniqueId(World* world, EntityId entityId); LuaNullTermWrapper<Maybe<String>> entityUniqueId(World* world, EntityId entityId);
Json getObjectParameter(World* world, EntityId entityId, String const& parameterName, Maybe<Json> const& defaultValue); Json getObjectParameter(World* world, EntityId entityId, String const& parameterName, Maybe<Json> const& defaultValue);
Json getNpcScriptParameter(World* world, EntityId entityId, String const& parameterName, Maybe<Json> const& defaultValue); Json getNpcScriptParameter(World* world, EntityId entityId, String const& parameterName, Maybe<Json> const& defaultValue);
List<Vec2I> objectSpaces(World* world, EntityId entityId); List<Vec2I> objectSpaces(World* world, EntityId entityId);