From 0f2eedd4fb3946ed2a776b3baccf2ea0cd98cf6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20B=C3=A1lint=20Misius?= Date: Sat, 23 Oct 2021 09:38:58 +0200 Subject: [PATCH] Make thread_local replacement for MinGW slightly nicer --- src/common/tpt-thread-local.cpp | 78 +++++++++++++-------------------- src/common/tpt-thread-local.h | 63 +++++++++++++++++--------- 2 files changed, 73 insertions(+), 68 deletions(-) diff --git a/src/common/tpt-thread-local.cpp b/src/common/tpt-thread-local.cpp index cde7ec95c..0a53422c1 100644 --- a/src/common/tpt-thread-local.cpp +++ b/src/common/tpt-thread-local.cpp @@ -5,58 +5,41 @@ # include # include -static pthread_once_t once = PTHREAD_ONCE_INIT; -static pthread_key_t key; - -struct ThreadLocalCommon +void *ThreadLocalCommon::Get() const { - size_t size; - void (*ctor)(void *); - void (*dtor)(void *); - size_t padding; -}; -static_assert(sizeof(ThreadLocalCommon) == 0x20, "fix me"); + // https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section + extern ThreadLocalCommon __start_tpt_tls; + extern ThreadLocalCommon __stop_tpt_tls; + static pthread_once_t once = PTHREAD_ONCE_INIT; + static pthread_key_t key; -struct ThreadLocalEntry -{ - void *ptr; -}; - -// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section -extern ThreadLocalCommon __start_tpt_tls; -extern ThreadLocalCommon __stop_tpt_tls; - -static void ThreadLocalDestroy(void *opaque) -{ - auto *staticsBegin = &__start_tpt_tls; - auto *staticsEnd = &__stop_tpt_tls; - auto staticsCount = staticsEnd - staticsBegin; - auto *liveObjects = reinterpret_cast(opaque); - if (liveObjects) + struct ThreadLocalEntry { - for (auto i = 0; i < staticsCount; ++i) - { - if (liveObjects[i].ptr) - { - staticsBegin[i].dtor(liveObjects[i].ptr); - free(liveObjects[i].ptr); - } - } - free(liveObjects); - } -} + void *ptr; + }; -static void ThreadLocalCreate() -{ - assert(!pthread_key_create(&key, ThreadLocalDestroy)); -} - -void *ThreadLocalGet(void *opaque) -{ auto *staticsBegin = &__start_tpt_tls; auto *staticsEnd = &__stop_tpt_tls; - auto *staticsOpaque = reinterpret_cast(opaque); - pthread_once(&once, ThreadLocalCreate); + pthread_once(&once, []() -> void { + assert(!pthread_key_create(&key, [](void *opaque) -> void { + auto *staticsBegin = &__start_tpt_tls; + auto *staticsEnd = &__stop_tpt_tls; + auto staticsCount = staticsEnd - staticsBegin; + auto *liveObjects = reinterpret_cast(opaque); + if (liveObjects) + { + for (auto i = 0; i < staticsCount; ++i) + { + if (liveObjects[i].ptr) + { + staticsBegin[i].dtor(liveObjects[i].ptr); + free(liveObjects[i].ptr); + } + } + free(liveObjects); + } + })); + }); auto *liveObjects = reinterpret_cast(pthread_getspecific(key)); if (!liveObjects) { @@ -65,7 +48,7 @@ void *ThreadLocalGet(void *opaque) assert(liveObjects); assert(!pthread_setspecific(key, reinterpret_cast(liveObjects))); } - auto idx = staticsOpaque - staticsBegin; + auto idx = this - staticsBegin; auto &entry = liveObjects[idx]; if (!entry.ptr) { @@ -75,5 +58,4 @@ void *ThreadLocalGet(void *opaque) } return entry.ptr; } - #endif diff --git a/src/common/tpt-thread-local.h b/src/common/tpt-thread-local.h index 01a05c02e..e6b6586f6 100644 --- a/src/common/tpt-thread-local.h +++ b/src/common/tpt-thread-local.h @@ -4,39 +4,62 @@ #ifdef __MINGW32__ # include -template -class ThreadLocal +class ThreadLocalCommon { - static void Ctor(Type *type) - { - new (type) Type(); - } + ThreadLocalCommon(const ThreadLocalCommon &other) = delete; + ThreadLocalCommon &operator =(const ThreadLocalCommon &other) = delete; - static void Dtor(Type *type) - { - type->~Type(); - } - - size_t size = sizeof(Type); - void (*ctor)(Type *) = Ctor; - void (*dtor)(Type *) = Dtor; +protected: + size_t size; + void (*ctor)(void *); + void (*dtor)(void *); size_t padding; + void *Get() const; + public: - Type *operator &() + ThreadLocalCommon() = default; + + static constexpr size_t Alignment = 0x20; +}; +// * If this fails, add or remove padding fields, possibly change Alignment to a larger power of 2. +static_assert(sizeof(ThreadLocalCommon) == ThreadLocalCommon::Alignment, "fix me"); + +template +class ThreadLocal : public ThreadLocalCommon +{ + static void Ctor(void *type) { - static_assert(sizeof(ThreadLocal) == 0x20, "fix me"); - void *ThreadLocalGet(void *opaque); - return reinterpret_cast(ThreadLocalGet(reinterpret_cast(this))); + new(type) Type(); } - operator Type &() + static void Dtor(void *type) + { + reinterpret_cast(type)->~Type(); + } + +public: + ThreadLocal() + { + // * If this fails, you're out of luck. + static_assert(sizeof(ThreadLocal) == sizeof(ThreadLocalCommon), "fix me"); + size = sizeof(Type); + ctor = Ctor; + dtor = Dtor; + } + + Type *operator &() const + { + return reinterpret_cast(Get()); + } + + operator Type &() const { return *(this->operator &()); } }; -# define THREAD_LOCAL(Type, tl) ThreadLocal tl __attribute__((section("tpt_tls"))) __attribute__((aligned(0x20))) +# define THREAD_LOCAL(Type, tl) const ThreadLocal tl __attribute__((section("tpt_tls"), aligned(ThreadLocalCommon::Alignment))) #else # define THREAD_LOCAL(Type, tl) thread_local Type tl #endif