Wednesday, August 29, 2018

Using Windows fibers in a simple way, but unexplainable bugs occur

I was experimenting with Windows fibers while implementing my own task scheduler when some odd crashes and undefined behavior appeared. To keep things simple, I started a new project and wrote a small program that performs the following operations:

  1. The main thread creates a bunch of fibers, then launches two threads
  2. The main thread waits until you kill the program
  3. Each worker thread converts itself into a fiber
  4. Each worker thread tries to find a free fiber, then switches to it
  5. Once a thread has switched to a new fiber, it pushes its previous fiber into the free fibers container
  6. Each worker thread goes back to step 4

If you are not familiar with the concept of fibers, this talk is a good start.

The Data

Each thread has its own ThreadData structure that stores its previous fiber, its current fiber, and its thread index. I tried several ways to retrieve the ThreadData structure during execution:

  • I used thread local storage to store a ThreadData pointer
  • I used a container which associates a thread_id with a ThreadData structure

The Problem

When a fiber is entered for the first time (see the FiberFunc function), the thread running it must push its previous fiber into the free fibers container. Sometimes, however, the previous fiber is null, which should be impossible: before switching to a new fiber, the thread sets its previous fiber to its current fiber, and then sets its current fiber to the new fiber.

So if a thread enters a brand new fiber with its previous fiber set to null, it would mean it came from nowhere (which doesn't make any sense).
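The bookkeeping invariant described above can be modeled without any real fibers. The following is only a minimal single-threaded sketch: FiberSlot, ThreadState, BeginSwitch and EndSwitch are hypothetical names introduced for illustration, not part of the program in this post. It shows why, as long as the same state object is seen before and after the switch, the previous fiber cannot be null on arrival.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a fiber; no Windows API is involved in this model.
struct FiberSlot
{
    int id;
};

// Hypothetical stand-in for the per-thread bookkeeping described in the post.
struct ThreadState
{
    FiberSlot* previousFiber{ nullptr };
    FiberSlot* currentFiber{ nullptr };
};

// Models the bookkeeping done just before the real switch:
// the current fiber becomes the previous one, the target becomes current.
void BeginSwitch(ThreadState& state, FiberSlot* next)
{
    assert(state.currentFiber);
    assert(next);
    state.previousFiber = state.currentFiber;
    state.currentFiber = next;
}

// Models the bookkeeping done on arrival in the target fiber:
// the fiber we came from is released to the free list.
void EndSwitch(ThreadState& state, std::vector<FiberSlot*>& freeList)
{
    assert(state.previousFiber); // the invariant the post relies on
    freeList.push_back(state.previousFiber);
    state.previousFiber = nullptr;
}
```

In this model the invariant holds trivially because BeginSwitch and EndSwitch receive the same ThreadState. The model says nothing about how the real program locates its ThreadData after an actual fiber switch, which is the part that differs between the lookup strategies tried here.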

The only ways a ThreadData could have a null previous fiber when its thread enters a brand new fiber are that another thread set it to null, or that the compiler reordered instructions under the hood.

I checked the assembly, and the compiler does not seem to be responsible.

There are several bugs I can't explain:

  1. If I use the first GetThreadData() function to retrieve the ThreadData structure, I can get an instance whose index differs from the thread local index (both are set when the thread starts). This makes the program assert (assert(threadData->index == localThreadIndex)).

  2. If I use either of the other functions to retrieve the ThreadData structure, the assert in the FiberFunc function fires because the previous fiber is null (assert(threadData->previousFiber)).

Do you have any idea why this code doesn't work? I have spent countless hours trying to figure out what is wrong, but I can't see my mistake.

The Code

I compiled and ran the code with Visual Studio 2015 using the VC++ compiler, in Release x64. You may need to run it several times before the assert fires.

#include "Windows.h"
#include <vector>
#include <thread>
#include <mutex>
#include <cassert>
#include <iostream>
#include <atomic>

struct Fiber
{
    void* handle;
};

struct ThreadData
{
    Fiber*  previousFiber{ nullptr };
    Fiber*  currentFiber{ nullptr };
    Fiber   fiber{ };
    unsigned int index{};
};

//Threads
std::vector<std::thread*> threads{};
std::vector<std::pair<std::thread::id, unsigned int>> threadsinfo{};

//threads data container
ThreadData  threadsData[8];

//Fibers
std::mutex  fibersLock{};
std::vector<Fiber> fibers{};
std::vector<Fiber*> freeFibers{};

thread_local unsigned int localThreadIndex{};
thread_local Fiber* debug_localTheadLastFiber{};
thread_local ThreadData* localThreadData{};

//This is the first way to retrieve the current thread's ThreadData structure using thread_id
//ThreadData* GetThreadData()
//{
//  std::thread::id threadId( std::this_thread::get_id());
//  for (auto const& pair : threadsinfo)
//  {
//      if (pair.first == threadId)
//      {
//          return &threadsData[pair.second];
//      }
//  }
//
//  //This point should never be reached
//  assert(false);
//  return nullptr;
//}

//This is the second way to retrieve the current thread's ThreadData structure using thread local storage
//ThreadData* GetThreadData()
//{
//  return &threadsData[localThreadIndex];
//}


//This is the third way to retrieve the current thread's ThreadData structure using thread local storage
ThreadData* GetThreadData()
{
    return localThreadData;
}


//Try to pop a free fiber from the container, thread safe due to mutex usage
bool  TryPopFreeFiber(Fiber*& fiber)
{
    std::lock_guard<std::mutex> guard(fibersLock);
    if (freeFibers.empty()) { return false; }
    fiber = freeFibers.back();
    assert(fiber);
    assert(fiber->handle);
    freeFibers.pop_back();
    return true;
}


//Try to push a free fiber to the container, thread safe due to mutex usage
bool PushFreeFiber(Fiber* fiber)
{
    std::lock_guard<std::mutex> guard(fibersLock);
    freeFibers.push_back(fiber);
    return true;
}


//the __declspec(noinline) is used to inspect code in release mode, comment it if you want
__declspec(noinline) void  SwitchToFiber(Fiber* newFiber)
{
    //You want to switch to another fiber
    //First save your current fiber instance so it can be released once you are running in the new fiber
    {
        ThreadData* threadData{ GetThreadData() };
        assert(threadData->index == localThreadIndex);
        assert(threadData->currentFiber);
        threadData->previousFiber = threadData->currentFiber;
        threadData->currentFiber = newFiber;
        debug_localTheadLastFiber = threadData->previousFiber;
        assert(threadData->previousFiber);
        assert(newFiber);
        assert(newFiber->handle);
    }

    //Switch to the new fiber
    //This call either enters the FiberFunc function if the fiber has never been used,
    //or resumes execution in this function if the new fiber has already been used (note that you will then be on a different stack, so you can't use the old threadData value)
    ::SwitchToFiber(newFiber->handle);

    {
        //You must fetch the current ThreadData* again: the previous statement is a switch, so you arrive here from another fiber, and this fiber may have last run on any other thread
        ThreadData* threadData{ GetThreadData() };
        assert(threadData);

        //THIS ASSERT FIRES IF YOU USE THE FIRST GetThreadData METHOD, WHICH SHOULD BE IMPOSSIBLE
        assert(threadData->index == localThreadIndex);

        assert(threadData->previousFiber);

        //We release the previous fiber
        PushFreeFiber(threadData->previousFiber);
        debug_localTheadLastFiber = nullptr;
        threadData->previousFiber = nullptr;
    }

}


void ExecuteThreadBody()
{
    Fiber*  newFiber{};

    if (TryPopFreeFiber(newFiber))
    {
        SwitchToFiber(newFiber);
    }
}


void ThreadFunc(unsigned int index)
{
    threadsinfo[index] = std::make_pair(std::this_thread::get_id(), index);

    //setting up the current thread data
    ThreadData* threadData{ &threadsData[index] };
    threadData->index = index;
    threadData->fiber = Fiber{ ConvertThreadToFiber(nullptr) };
    threadData->currentFiber = &threadData->fiber;

    localThreadData = threadData;
    localThreadIndex = index;

    while (true)
    {
        ExecuteThreadBody();
    }
}


//The entry point of all fibers
void __stdcall FiberFunc(void* data)
{
    //You are entering the fiber for the first time

    ThreadData* threadData{ GetThreadData() };

    //Make sure the thread data structure is the right one
    assert(threadData->index == localThreadIndex);

    //This is the assert that fires
    assert(threadData->previousFiber);

    PushFreeFiber(threadData->previousFiber);
    threadData->previousFiber = nullptr;

    while (true)
    {
        ExecuteThreadBody();
    }
}


__declspec(noinline) void main()
{
    constexpr unsigned int threadCount{ 2 };
    constexpr unsigned int fiberCount{ 20 };

    threadsinfo.resize(threadCount);

    fibers.resize(fiberCount);
    for (auto index = 0; index < fiberCount; ++index)
    {
        fibers[index] = { CreateFiber(0, FiberFunc, nullptr) };
    }

    freeFibers.resize(fiberCount);
    for (auto index = 0; index < fiberCount; ++index)
    {
        freeFibers[index] = std::addressof(fibers[index]);
    }

    threads.resize(threadCount);

    for (auto index = 0; index < threadCount; ++index)
    {
        threads[index] = new std::thread{ ThreadFunc, index };
    }

    while (true);

    //I know, it is not clean, it will leak
}
