请教大家一个 C++线程池的问题

wisefree · 2024-6-17 20:25:57 · 106 次点击
最近在找一个简单的 C++11 线程池实现,发现网上有很多相关的代码,在 CSDN 网上看到一个比较简洁的。但是总感觉是不是实现错了。
1. Any 类 noncopyable 的,仅仅支持移动语义,
2. Result 类使用了 Any 实例作为成员变量,那么 Result 类应该也是 noncopyable 的,
3. `Result SubmitTask(std::shared_ptr<Task> taskPtr);`直接使用了复制语义,应该是有问题吧,可是代码能够被 vs2022 正常编译。


threadpool.h

``` C++
#pragma once
#include <vector>
#include <cstdint>
#include <queue>
#include <memory>
#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <functional>
#include <sstream>
#include <unordered_map>


// Any 类型:可以接收任意数据的类型
// 任意其他类型 template
// 能让一个类型指向其他类型,基类指针可以指向子类
class Any
{
public:
        Any() = default;
        ~Any() = default;
        Any(const Any&) = delete;
        Any& operator=(const Any&) = delete;
        Any(Any&&) = default;
        Any& operator=(Any&&) = default;

        template<typename T>
        Any(T data) : m_base(std::make_unique<Derive<T>>(data)) {}

        template<typename T>
        T cast_()
        {
                Derive<T>* pd = dynamic_cast<Derive<T>*>(m_base.get());

                if (pd == nullptr) {
                        throw "type is unmath!!";
                }

                return pd->m_data;
        }

private:
        // 基类
        class Base
        {
        public:
                virtual ~Base() = default;
        };

        // 派生类
        template<typename T>
        class Derive : public Base
        {
        public:
                Derive(T data) : m_data(data) {}
        public:
                T m_data;
        };

private:
        std::unique_ptr<Base> m_base;
};


// 实现一个信号量类
class Semaphore
{
public:
        Semaphore(int limit = 0) : m_resLimit(limit)
        {}

        ~Semaphore() = default;

        // 获取一个信号量资源
        void wait()
        {
                std::unique_lock<std::mutex> lock(m_mtx);
                // 如果没有资源,阻塞线程
                while (m_resLimit < 1) {
                        m_cond.wait(lock);
                }

                m_resLimit--;
        }

        // 增加一个信号量资源
        void post()
        {
                std::unique_lock<std::mutex> lock(m_mtx);
                m_resLimit++;
                m_cond.notify_all();

        }
private:
        int m_resLimit;  // 资源量
        std::mutex m_mtx;
        std::condition_variable m_cond;
};


// Task 类型前置声明
class Task;

// 实现接收提交到线程池的 task 任务执行完成后的返回值类型
class Result
{
public:
        Result(std::shared_ptr<Task> task, bool isValid = true);
        ~Result() = default;

        // setVal
        void setVal(Any result);

        // get 方法,用户调用这个方法获取 task 的返回值
        Any get();
private:
        Any m_any;
        Semaphore m_sem;
        std::shared_ptr<Task> m_task;
        std::atomic_bool m_isValid;
};


// 任务抽象基类
class Task
{
public:
        void exec();
        void setResult(Result* res);
        virtual Any run() = 0;

private:
        Result* m_result{ nullptr };  // 不要用智能指针,task 含有 Result  Result 含有 task ,可能导致问题
};

class MyTask : public Task
{
public:
        MyTask(int start, int end) : m_start(start), m_end(end) {}

        Any run()
        {
                std::ostringstream ostr;
                ostr << std::this_thread::get_id();
                printf("thead %s, task start \n", ostr.str().c_str());

                uint64_t sum = 0;

                for (int i = m_start; i <= m_end; i++) {
                        sum += i;
                }

                printf("sum %llu\n", sum);
                std::this_thread::sleep_for(std::chrono::seconds(2));
                printf("thread %s, task finish \n", ostr.str().c_str());

                return sum;
        }

private:
        int m_start;
        int m_end;
};

enum ThreadPoolMode
{
        MODE_FIXED,  // 固定数量的线程
        MODE_CACHED,  // 线程数量可以动态增长
};

class Thread
{
public:
        using ThreadFunc = std::function<void(int)>;

        Thread(ThreadFunc func);
        ~Thread();

        void Start();
        int GetId() { return m_threadId; }
private:
        ThreadFunc m_func;
        static int generateId;
        int m_threadId;
};


class ThreadPool
{
public:
        ThreadPool();
        ~ThreadPool();

        // 设置线程池工作模式
        void SetMode(ThreadPoolMode mode);

        // 设置任务数量上限
        void SetTaskQueMaxThreshold(int value);

        // 给线程池提交任务
        Result SubmitTask(std::shared_ptr<Task> taskPtr);

        // 开启线程池
        void Start(int initThreadSize = std::thread::hardware_concurrency());

private:
        ThreadPool(const ThreadPool&) = delete;
        ThreadPool& operator=(const ThreadPool&) = delete;

        // 定义线程函数
        void ThreadFunc(int threadId);
        bool CheckRunningState() const;

private:
        std::unordered_map<int, std::unique_ptr<Thread>> m_threadMap;  // 线程列表
        int m_initThreadSize;  // 初始的线程数量
        std::atomic_int m_curThreadSize;  // 当前线程数量

        std::queue<std::shared_ptr<Task>> m_taskQue;  // 任务队列
        std::atomic_int m_taskSize;  // 任务的数量
        int m_taskQueMaxThreshold;  // 任务队列的数量上限

        std::mutex m_taskQueMtx;  // 保证任务队列的线程安全
        std::condition_variable m_taskQueNotFullCv;  // 表示任务队列不满
        std::condition_variable m_taskQueNotEmptyCv;  // 表示任务队列不空
        std::condition_variable m_exitCv;  // 退出线程池

        ThreadPoolMode m_poolMode;  // 当前线程池的工作模式
        std::atomic_bool m_isPoolRuning;  // 当前线程工作状态
};
```

threadpool.cpp

``` C++
#include "threadpool.h"
#include <functional>
#include <iostream>

constexpr int TASK_MAX_THRESHOLD = 1024;

ThreadPool::ThreadPool() : m_initThreadSize(4), m_taskSize(0),
m_taskQueMaxThreshold(TASK_MAX_THRESHOLD),
m_poolMode(ThreadPoolMode::MODE_FIXED)
{
}

ThreadPool::~ThreadPool()
{
        m_isPoolRuning = false;
        std::unique_lock<std::mutex> lock(m_taskQueMtx);

        // 线程 要么在阻塞中 要么在工作中
        while (m_threadMap.size() > 0) {
                m_taskQueNotEmptyCv.notify_all();  // 唤醒等待的工作线程
                m_exitCv.wait(lock);
        }
}

void ThreadPool::SetMode(ThreadPoolMode mode)
{
        if (m_isPoolRuning) { return; }  // 线程池启动后,不允许设置线程池一些参数

        m_poolMode = mode;
}

void ThreadPool::SetTaskQueMaxThreshold(int value)
{
        if (m_isPoolRuning) { return; }

        m_taskQueMaxThreshold = value;
}

Result ThreadPool::SubmitTask(std::shared_ptr<Task> taskPtr)
{
        // 获取锁
        std::unique_lock<std::mutex> lock(m_taskQueMtx);

        // 线程通信,检查任务队列是否有空余
        while (m_taskQue.size() >= m_taskQueMaxThreshold) {

                // 用于提交任务,不能阻塞太长时间,如果超过 1s ,给用户返回提交失败
                if (m_taskQueNotFullCv.wait_for(lock, std::chrono::seconds(1)) == std::cv_status::timeout) {
                        return Result(taskPtr, false);
                }
        }

        // 如果有空余,把任务提交到任务队列中
        m_taskQue.emplace(taskPtr);
        m_taskSize++;

        // 因为新放了任务,任务队列肯定不为空了,在 m_taskQueNotEmptyCv 进行通知,赶快分配线程执行这个任务
        m_taskQueNotEmptyCv.notify_all();

        return Result(taskPtr);
}

void ThreadPool::Start(int initThreadSize)
{
        m_initThreadSize = initThreadSize;
        m_curThreadSize = initThreadSize;
    m_isPoolRuning = true;

        // 创建线程对象
        for (int i = 0; i < m_initThreadSize; i++) {
                auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
                int threadId = ptr->GetId();
                m_threadMap.emplace(threadId, std::move(ptr));
        }

        // 启动所有线程
        for (auto iter = m_threadMap.cbegin(); iter != m_threadMap.end(); iter++) {
                iter->second->Start();
        }
}

void ThreadPool::ThreadFunc(int threadId)
{
        while (true) {

                // 获取锁
                std::unique_lock<std::mutex> lock(m_taskQueMtx);

                std::ostringstream ostr;
                ostr << std::this_thread::get_id();
                printf("thead %s, To Get task \n", ostr.str().c_str());

                // 判断任务队列是否为空
                while (m_taskQue.empty()) {
                        if (!m_isPoolRuning) {
                                m_threadMap.erase(threadId);
                                m_exitCv.notify_all();

                                printf("deconstructor thread exit, id = %d\n", threadId);
                                return;
                        }
            
                        m_taskQueNotEmptyCv.wait(lock);

                }

                printf("thead %s, Getted task \n", ostr.str().c_str());
                // 不为空,获取任务
                auto taskPtr = m_taskQue.front();  // front()返回引用,auto 忽略引用属性,正好满足需要
                m_taskQue.pop();
                m_taskSize--;

                lock.unlock();  // 释放锁;

                // 如果任务队列还有任务,通知其他线程执行任务
                if (m_taskQue.size() > 0) {
                        m_taskQueNotEmptyCv.notify_all();
                }

                // 通知队列已经不满
                m_taskQueNotFullCv.notify_all();

                taskPtr->exec();

                if (!m_isPoolRuning) {
                        m_threadMap.erase(threadId);
                        m_exitCv.notify_all();

                        printf("deconstructor thread exit, id = %d\n", threadId);
                        return;
                }

        }
}

bool ThreadPool::CheckRunningState() const
{
        if (m_isPoolRuning) {
                return true;
        }

        return false;
}

// 线程方法
int Thread::generateId = 0;

Thread::Thread(ThreadFunc func) : m_func(func),
                                                                m_threadId(generateId++)
{
}

Thread::~Thread()
{
}

void Thread::Start()
{
        std::thread t(m_func, m_threadId);
        t.detach();
}

Result::Result(std::shared_ptr<Task> task, bool isValid) : m_task(task), m_isValid(isValid)
{
        m_task->setResult(this);
}

void Result::setVal(Any result)
{
        m_any = std::move(result);
        m_sem.post();  // 通知已经获得结果
}

Any Result::get()
{
        if (!m_isValid) {
                return "";
        }

        m_sem.wait();  // 等待结果
        return std::move(m_any);
}


void Task::exec()
{
        if (m_result != nullptr) {
                Any result = run();  // 这里发生多态调用

                m_result->setVal(std::move(result));
        }
}

void Task::setResult(Result* res)
{
        m_result = res;
}

```

main.cpp

``` C++
#include "threadpool.h"

#include <chrono>
#include <iostream>

using std::cout;
using std::endl;


int main(int argc, char* argv[])
{
        {
                ThreadPool pool;
                pool.Start(4);

                Result res1 = pool.SubmitTask(std::make_shared<MyTask>(1, 100000000));
                Result res2 = pool.SubmitTask(std::make_shared<MyTask>(100000001, 200000000));
                Result res3 = pool.SubmitTask(std::make_shared<MyTask>(200000001, 300000000));

                //uint64_t sum1 = res1.get().cast_<uint64_t>();
                //uint64_t sum2 = res2.get().cast_<uint64_t>();
                //uint64_t sum3 = res3.get().cast_<uint64_t>();

                //cout << (sum1 + sum2 + sum3) << endl;
        }

        cout << "main ofer" << endl;

        getchar();
        return 0;
}
```
举报· 106 次点击
登录 注册 站外分享
5 条回复  
donaldturinglee 小成 2024-6-17 21:44:25
你在 Github 按照 star 挑几个高星的看看
ysc3839 小成 2024-6-17 22:02:07
shared_ptr 复制只是增加引用计数吧?底层对象没复制。
zhaoloving 小成 2024-6-17 22:18:22
有右值构造函数,函数返回一个右值就好了
leonshaw 小成 2024-6-17 23:10:25
Any(Any&&) = default;
xyz1001 小成 2024-6-18 08:56:06
返回的 Result 是临时变量,属于将亡值,也是右值的一种,走的右值拷贝构造
返回顶部