Make thread_local replacement for MinGW slightly nicer

This commit is contained in:
Tamás Bálint Misius 2021-10-23 09:38:58 +02:00
parent 0ed8d0a0be
commit 0f2eedd4fb
No known key found for this signature in database
GPG Key ID: 5B472A12F6ECA9F2
2 changed files with 73 additions and 68 deletions

View File

@ -5,29 +5,23 @@
# include <cstdlib> # include <cstdlib>
# include <cassert> # include <cassert>
void *ThreadLocalCommon::Get() const
{
// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section
extern ThreadLocalCommon __start_tpt_tls;
extern ThreadLocalCommon __stop_tpt_tls;
static pthread_once_t once = PTHREAD_ONCE_INIT; static pthread_once_t once = PTHREAD_ONCE_INIT;
static pthread_key_t key; static pthread_key_t key;
struct ThreadLocalCommon
{
size_t size;
void (*ctor)(void *);
void (*dtor)(void *);
size_t padding;
};
static_assert(sizeof(ThreadLocalCommon) == 0x20, "fix me");
struct ThreadLocalEntry struct ThreadLocalEntry
{ {
void *ptr; void *ptr;
}; };
// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section auto *staticsBegin = &__start_tpt_tls;
extern ThreadLocalCommon __start_tpt_tls; auto *staticsEnd = &__stop_tpt_tls;
extern ThreadLocalCommon __stop_tpt_tls; pthread_once(&once, []() -> void {
assert(!pthread_key_create(&key, [](void *opaque) -> void {
static void ThreadLocalDestroy(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls; auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls; auto *staticsEnd = &__stop_tpt_tls;
auto staticsCount = staticsEnd - staticsBegin; auto staticsCount = staticsEnd - staticsBegin;
@ -44,19 +38,8 @@ static void ThreadLocalDestroy(void *opaque)
} }
free(liveObjects); free(liveObjects);
} }
} }));
});
static void ThreadLocalCreate()
{
assert(!pthread_key_create(&key, ThreadLocalDestroy));
}
void *ThreadLocalGet(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto *staticsOpaque = reinterpret_cast<ThreadLocalCommon *>(opaque);
pthread_once(&once, ThreadLocalCreate);
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(pthread_getspecific(key)); auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(pthread_getspecific(key));
if (!liveObjects) if (!liveObjects)
{ {
@ -65,7 +48,7 @@ void *ThreadLocalGet(void *opaque)
assert(liveObjects); assert(liveObjects);
assert(!pthread_setspecific(key, reinterpret_cast<void *>(liveObjects))); assert(!pthread_setspecific(key, reinterpret_cast<void *>(liveObjects)));
} }
auto idx = staticsOpaque - staticsBegin; auto idx = this - staticsBegin;
auto &entry = liveObjects[idx]; auto &entry = liveObjects[idx];
if (!entry.ptr) if (!entry.ptr)
{ {
@ -75,5 +58,4 @@ void *ThreadLocalGet(void *opaque)
} }
return entry.ptr; return entry.ptr;
} }
#endif #endif

View File

@ -4,39 +4,62 @@
#ifdef __MINGW32__ #ifdef __MINGW32__
# include <cstddef> # include <cstddef>
template<class Type> class ThreadLocalCommon
class ThreadLocal
{ {
static void Ctor(Type *type) ThreadLocalCommon(const ThreadLocalCommon &other) = delete;
ThreadLocalCommon &operator =(const ThreadLocalCommon &other) = delete;
protected:
size_t size;
void (*ctor)(void *);
void (*dtor)(void *);
size_t padding;
void *Get() const;
public:
ThreadLocalCommon() = default;
static constexpr size_t Alignment = 0x20;
};
// * If this fails, add or remove padding fields, possibly change Alignment to a larger power of 2.
static_assert(sizeof(ThreadLocalCommon) == ThreadLocalCommon::Alignment, "fix me");
template<class Type>
class ThreadLocal : public ThreadLocalCommon
{
static void Ctor(void *type)
{ {
new(type) Type(); new(type) Type();
} }
static void Dtor(Type *type) static void Dtor(void *type)
{ {
type->~Type(); reinterpret_cast<Type *>(type)->~Type();
} }
size_t size = sizeof(Type);
void (*ctor)(Type *) = Ctor;
void (*dtor)(Type *) = Dtor;
size_t padding;
public: public:
Type *operator &() ThreadLocal()
{ {
static_assert(sizeof(ThreadLocal<Type>) == 0x20, "fix me"); // * If this fails, you're out of luck.
void *ThreadLocalGet(void *opaque); static_assert(sizeof(ThreadLocal<Type>) == sizeof(ThreadLocalCommon), "fix me");
return reinterpret_cast<Type *>(ThreadLocalGet(reinterpret_cast<void *>(this))); size = sizeof(Type);
ctor = Ctor;
dtor = Dtor;
} }
operator Type &() Type *operator &() const
{
return reinterpret_cast<Type *>(Get());
}
operator Type &() const
{ {
return *(this->operator &()); return *(this->operator &());
} }
}; };
# define THREAD_LOCAL(Type, tl) ThreadLocal<Type> tl __attribute__((section("tpt_tls"))) __attribute__((aligned(0x20))) # define THREAD_LOCAL(Type, tl) const ThreadLocal<Type> tl __attribute__((section("tpt_tls"), aligned(ThreadLocalCommon::Alignment)))
#else #else
# define THREAD_LOCAL(Type, tl) thread_local Type tl # define THREAD_LOCAL(Type, tl) thread_local Type tl
#endif #endif