使用cpp
实现一个接口友好, 并发安全的线程池.
线程池的作用
当我们需要频繁创建线程用于计算密集型任务时, 创建和释放需要频繁调用内核API进行线程的创建、切换和资源回收等, 会造成服务的处理性能下降. 线程池就是预先向操作系统申请一些工作线程, 等待用户将(计算)任务放入线程池交由线程池中的线程运行, 提升服务处理高并发任务的能力.
线程池的原理
当有任务放入线程池时, 线程池会唤醒工作线程, 工作线程会拿走任务进行计算. 而当线程池工作队列满时, 放入任务会阻塞直到工作队列非满.
由以上逻辑分析可见, 线程池就是典型的生产者-消费者问题.
通常一个线程池应该有以下参数:
- 工作线程个数
- 工作队列大小
- 主线程
其中, 消费者就是工作线程, 而生产者是用户线程.
逻辑实现
生产者-消费者问题中, 缓冲区就是临界资源. 而生产者和消费者需要互斥访问临界资源, 这通过互斥锁mutex
实现. 而线程的唤醒(如生产者放入缓冲区唤醒消费者)可以通过条件变量或者信号量(cpp20
中semaphore
库)来实现.
线程池中消费者(工作线程)和工作队列是线程池应该管理的内容. 任务的添加取决于用户, 线程池需要提供添加任务和结束线程池的接口给用户. 同时由于结束线程池的时候工作队列可能还有计算任务, 因此需要等待所有计算任务运行完毕后才能结束线程池, 释放线程资源.
代码实现
线程池内部参数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| std::vector<std::thread> workers_;
std::deque<Task> tasks_;
size_t maxQueueSize_;
size_t numOfThreads_;
bool running_;
std::mutex mtx_; std::condition_variable notFull_, notEmpty_;
|
线程池构造函数
注意线程池不可拷贝赋值以及不可拷贝构造, 而且需要显示构造.
1 2
| explicit ThreadPool(size_t numOfThreads = 0, size_t maxQueueSize = 20) : numOfThreads_(numOfThreads), maxQueueSize_(maxQueueSize), running_(false) {}
|
开启线程池
1 2 3 4 5 6 7
| void ThreadPool::start () { running_ = true; for (int i = 0; i < numOfThreads_; i ++ ) workers_.emplace_back(std::thread(&ThreadPool::consumeTask, this)); }
|
添加任务接口
为了提供用户友好的接口, 我使用了模板函数和可变参数模板来实现添加任务的接口(参考cpp标准库的实现).
1 2 3 4 5 6 7 8 9
|
template<typename _Callable, typename... _Args> void addTask(_Callable&& __f, _Args&&... __args) { Task task = std::bind(std::forward<_Callable>(__f),std::forward<_Args...>(__args...)); addToWorkingPool(task); }
|
其中线程池内部将任务定义为:
1
| using Task = std::function<void()>;
|
添加任务函数的实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| void ThreadPool::addToWorkingPool(Task task) { if (!running_) throw std::runtime_error ("ThreadPool must running when adding tasks.");
if (numOfThreads_ == 0) { task(); return ; }
std::unique_lock<std::mutex> lock(mtx_); while (isFull() and running_) notFull_.wait(lock);
if (!running_) return ;
tasks_.emplace_back(move(task)); notEmpty_.notify_one(); }
|
添加任务到线程池时, 如果线程池正在运行且工作队列已满, 则该生产者阻塞等待线程池非满notFull
条件变量. 同时为了防止虚假唤醒, 使用while
进行该条件的判断.
虚假唤醒: 比如工作队列为空且生产者添加了一个任务并唤醒了当前工作线程, 而这个任务被另一个没有阻塞的线程执行了, 那么当当前线程拿到互斥锁之后, 如果不判断就直接取任务, 会存在任务队列为空的可能(消费者的解释,生产者同理, 即可能被其他生产者放入任务队列了, 此时再放就会逻辑出错.)
当线程池在运行时, 添加任务到工作队列即可.
工作线程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| void ThreadPool::consumeTask() { while (true) { Task task; { std::unique_lock<std::mutex> lock(mtx_); while (tasks_.empty() and running_) notEmpty_.wait(lock);
if (!running_ and tasks_.empty()) return ;
assert(tasks_.size()); task = tasks_.front(); tasks_.pop_front(); notFull_.notify_one(); } task(); } }
|
工作进程是消费者, 需要不断从工作队列中互斥获取任务并执行.
当线程池在运行且工作队列为空时, 工作线程需要阻塞等待工作队列非空的条件变量(notEmpty
). 注意使用while
判断从而避免虚假唤醒的问题.
当线程池停止运行且工作队列为空时, 此时计算任务全被执行完毕且不会有新的任务添加入线程池, 因此结束工作线程. 而如果线程池停止运行且工作队列不为空时, 此时任然需要执行计算任务.
计算任务的执行在unique_lock
的作用域之外(使用RAII机制的资源管理类, 其离开作用域时析构函数会释放持有的锁资源), 否则会严重影响锁的释放和获取, 导致性能严重下降.
结束线程池
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| void ThreadPool::stop() { if (!running_) throw std::runtime_error ("ThreadPool must running when stop it.");
{ std::unique_lock<std::mutex> lock(mtx_); running_ = false; notEmpty_.notify_all(); notFull_.notify_all(); }
for (auto& t : workers_) t.join(); }
|
结束线程池时全部唤醒阻塞在notFull
上的生产者和notEmpty
上的消费者线程. 然后回收线程资源.
析构线程池
1 2 3 4 5
| ~ThreadPool() { if (running_) stop(); }
|
如果还在运行, 调用stop
函数结束线程池即可.
性能评估
比较单线程计算和多生产者线程池模型:
1 2 3 4 5
| oneThreadRun cost time = 79572 ms
ThreadPool cost time = 10155 ms
Process finished with exit code 0
|
遇到的问题和解决方式
- 性能低的问题
将计算任务Task
的执行放在释放锁之后.
- 编写多生产者模型时的BUG
最初使用匿名函数实现多生产者模式是这样写的:
1 2 3 4 5 6 7 8 9 10 11 12
| for (int i = 0; i < producers; i ++ ) { produceThread[i] = thread([&] () -> void { for (int j = 0; j < M; j ++ ) { futu[i][j] = prom[i][j].get_future(); threadPool.addTask(func, ref(prom[i][j])); } finish.fetch_add(1); }); }
|
写完死活调试都会出现BUG(段错误), 在windows平台和linux平台(WSL2)均会出现问题. 最后重写多生产者模式(和单生产者类似)才解决这个问题. 问题就是使用thread
并传入匿名函数
时, 不能通过引用去捕获参数, 而应该通过std::ref
包装传递. 浪费了我大半天…
1 2 3 4 5 6 7 8 9 10 11 12
| for (int i = 0; i < producers; i ++ ) { produceThread[i] = thread([] (ThreadPool& threadPool, int i) -> void { for (int j = 0; j < M; j ++ ) { futu[i][j] = prom[i][j].get_future(); threadPool.addTask(func, ref(prom[i][j])); } finish.fetch_add(1); }, ref(threadPool), i); }
|
- 多生产者的设计问题
多生产者应该由用户决定和资源回收.
代码
GitHubCode