Make thread_local replacement for MinGW slightly nicer

This commit is contained in:
Tamás Bálint Misius 2021-10-23 09:38:58 +02:00
parent 0ed8d0a0be
commit 0f2eedd4fb
No known key found for this signature in database
GPG Key ID: 5B472A12F6ECA9F2
2 changed files with 73 additions and 68 deletions

View File

@ -5,58 +5,41 @@
# include <cstdlib>
# include <cassert>
static pthread_once_t once = PTHREAD_ONCE_INIT;
static pthread_key_t key;
struct ThreadLocalCommon
void *ThreadLocalCommon::Get() const
{
size_t size;
void (*ctor)(void *);
void (*dtor)(void *);
size_t padding;
};
static_assert(sizeof(ThreadLocalCommon) == 0x20, "fix me");
// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section
extern ThreadLocalCommon __start_tpt_tls;
extern ThreadLocalCommon __stop_tpt_tls;
static pthread_once_t once = PTHREAD_ONCE_INIT;
static pthread_key_t key;
struct ThreadLocalEntry
{
void *ptr;
};
// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section
extern ThreadLocalCommon __start_tpt_tls;
extern ThreadLocalCommon __stop_tpt_tls;
static void ThreadLocalDestroy(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto staticsCount = staticsEnd - staticsBegin;
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(opaque);
if (liveObjects)
struct ThreadLocalEntry
{
for (auto i = 0; i < staticsCount; ++i)
{
if (liveObjects[i].ptr)
{
staticsBegin[i].dtor(liveObjects[i].ptr);
free(liveObjects[i].ptr);
}
}
free(liveObjects);
}
}
void *ptr;
};
static void ThreadLocalCreate()
{
assert(!pthread_key_create(&key, ThreadLocalDestroy));
}
void *ThreadLocalGet(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto *staticsOpaque = reinterpret_cast<ThreadLocalCommon *>(opaque);
pthread_once(&once, ThreadLocalCreate);
pthread_once(&once, []() -> void {
assert(!pthread_key_create(&key, [](void *opaque) -> void {
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto staticsCount = staticsEnd - staticsBegin;
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(opaque);
if (liveObjects)
{
for (auto i = 0; i < staticsCount; ++i)
{
if (liveObjects[i].ptr)
{
staticsBegin[i].dtor(liveObjects[i].ptr);
free(liveObjects[i].ptr);
}
}
free(liveObjects);
}
}));
});
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(pthread_getspecific(key));
if (!liveObjects)
{
@ -65,7 +48,7 @@ void *ThreadLocalGet(void *opaque)
assert(liveObjects);
assert(!pthread_setspecific(key, reinterpret_cast<void *>(liveObjects)));
}
auto idx = staticsOpaque - staticsBegin;
auto idx = this - staticsBegin;
auto &entry = liveObjects[idx];
if (!entry.ptr)
{
@ -75,5 +58,4 @@ void *ThreadLocalGet(void *opaque)
}
return entry.ptr;
}
#endif

View File

@ -4,39 +4,62 @@
#ifdef __MINGW32__
# include <cstddef>
template<class Type>
class ThreadLocal
class ThreadLocalCommon
{
static void Ctor(Type *type)
{
new (type) Type();
}
ThreadLocalCommon(const ThreadLocalCommon &other) = delete;
ThreadLocalCommon &operator =(const ThreadLocalCommon &other) = delete;
static void Dtor(Type *type)
{
type->~Type();
}
size_t size = sizeof(Type);
void (*ctor)(Type *) = Ctor;
void (*dtor)(Type *) = Dtor;
protected:
size_t size;
void (*ctor)(void *);
void (*dtor)(void *);
size_t padding;
void *Get() const;
public:
Type *operator &()
ThreadLocalCommon() = default;
static constexpr size_t Alignment = 0x20;
};
// * If this fails, add or remove padding fields, possibly change Alignment to a larger power of 2.
static_assert(sizeof(ThreadLocalCommon) == ThreadLocalCommon::Alignment, "fix me");
template<class Type>
class ThreadLocal : public ThreadLocalCommon
{
static void Ctor(void *type)
{
static_assert(sizeof(ThreadLocal<Type>) == 0x20, "fix me");
void *ThreadLocalGet(void *opaque);
return reinterpret_cast<Type *>(ThreadLocalGet(reinterpret_cast<void *>(this)));
new(type) Type();
}
operator Type &()
static void Dtor(void *type)
{
reinterpret_cast<Type *>(type)->~Type();
}
public:
ThreadLocal()
{
// * If this fails, you're out of luck.
static_assert(sizeof(ThreadLocal<Type>) == sizeof(ThreadLocalCommon), "fix me");
size = sizeof(Type);
ctor = Ctor;
dtor = Dtor;
}
Type *operator &() const
{
return reinterpret_cast<Type *>(Get());
}
operator Type &() const
{
return *(this->operator &());
}
};
# define THREAD_LOCAL(Type, tl) ThreadLocal<Type> tl __attribute__((section("tpt_tls"))) __attribute__((aligned(0x20)))
# define THREAD_LOCAL(Type, tl) const ThreadLocal<Type> tl __attribute__((section("tpt_tls"), aligned(ThreadLocalCommon::Alignment)))
#else
# define THREAD_LOCAL(Type, tl) thread_local Type tl
#endif