生产者消费者模式在多batch推理下的应用(延时队列)
需求
本文代码框架参考自 trtpy: https://zhuanlan.zhihu.com/p/462980738
创建一种生产者消费者模式,实现多线程资源处理,加速模型推理。
需要满足一下几点:
- 生产者生产提交资源,消费者负责处理,并将处理结果返回给生产者。
- 资源队列不是有限的,因为显存是有限的。
- 为了多batch加速,允许生产者生产资源间隔在一定范围内,比如50ms内的资源都可以放到一个batch当给消费者处理
- 保证资源哪里分配哪里释放,哪里使用,这样可以使得程序足够简单。
- 接口模式保证RAII,不暴露内部接口。
设计分析
- 生产者消费者经典框架-需要一个线程安全的队列来储存资源,需要使用
promise
和future
来完成对结果的返回。
- 需要对队列的大小加以控制,并且暴露控制大小的接口
- 需要设置一种延时机制,来控制消费者线程在收集资源时可以等待一段时间用于判断后续还有没有数据传入,此时可以使用
condition_variable
的wait_for
函数。
- 资源的分配和释放都放到消费者线程中,中间结果的消息传递可以采用使用
promise
和future
。
- 在C++中,一般采用抽象类来代替接口,再设置一个构造方法通过多态完成封装。
代码模拟实现
- 文件
Infer.hpp
, 接口声明文件
| #ifndef INFER_HPP
#define INFER_HPP
#include <memory>
#include <string>
#include <map>
#include <future>
namespace Infer {
using Tensor = std::string;
using RET_TYPE = std::map<std::string, Tensor>;
class InferInterface {
public:
// pic_tensor 表示输入给生产者的tensor, timeout 表示超时时间
// wait_timeout 表示当队列满了的情况下,愿意最多等待 wait_timeout ms 来等待
virtual std::shared_future<RET_TYPE> forward(const Tensor& pic_tensor, int wait_timeout = 10) = 0;
virtual void set_timeout(const int timeout) = 0;
virtual void set_max_queue_num(const int max_queue_num) = 0;
virtual ~InferInterface() {}
explicit InferInterface() {}
protected:
InferInterface(const InferInterface&) = delete;
InferInterface(InferInterface&&) = delete;
InferInterface& operator=(const InferInterface&) = delete;
};
std::shared_ptr<InferInterface> create_infer(const std::string& filepath, int max_bactch_size);
}
#endif // INFER_HPP
|
- 文件
Infer.cpp
,分析需求中的具体实现
| #include <iostream>
#include <thread>
#include <future>
#include <queue>
#include <map>
#include "infer.hpp"
namespace Infer {
using std::string;
using std::thread;
using std::promise;
using std::future;
using std::map;
using std::queue;
using std::atomic;
using std::mutex;
class InferImpl: public InferInterface {
protected:
// 定义队列中单个任务结果
using RET_RPOMISE_PTR = std::shared_ptr<promise<RET_TYPE>>;
struct Job {
public:
// 输入图像矩阵
Tensor pic;
// 对输入图像矩阵处理后返回的结果
RET_RPOMISE_PTR ret_promise;
};
// 这里假设是打开文件后的上下文
string context_;
thread thread_;
queue<Job> queue_;
atomic<bool> running_;
mutex lock_;
atomic<int> max_batchsize_;
atomic<int> timeout_;
atomic<int> max_queue_num_;
std::condition_variable cond_var_;
std::condition_variable cond_queue_overflow_;
public:
explicit InferImpl() {
running_ = false;
max_batchsize_ = 1;
timeout_ = 0;
max_queue_num_ = 5;
}
bool load_model(const string& filepath) {
promise<bool> init_promise;
thread_ = thread(&InferImpl::worker, this, filepath, std::ref(init_promise));
running_ = true;
return init_promise.get_future().get();
}
inline int max_batchsize() const {return max_batchsize_;}
bool set_max_batchsize(const int max_batchsize) {
if (max_batchsize < 1)
return false;
this->max_batchsize_ = max_batchsize;
return true;
}
void set_timeout(const int timeout) {
timeout_ = timeout;
}
void set_max_queue_num(const int max_queue_num) {
max_queue_num_ = max_queue_num;
}
virtual std::shared_future<RET_TYPE> forward(const Tensor& pic_tensor, int wait_timeout) override {
// 给队列提交输入并将处理结果返回
Job job;
job.pic = pic_tensor;
job.ret_promise = RET_RPOMISE_PTR(new promise<RET_TYPE>());
{
std::unique_lock<mutex> l(lock_);
if (queue_.size() >= max_queue_num_) {
if (0 == wait_timeout) {
throw std::runtime_error("exhausted resource");
}
cond_queue_overflow_.wait_for(l, std::chrono::milliseconds(wait_timeout), [&]() {
return queue_.size() < max_queue_num_;
});
}
if (queue_.size() >= max_queue_num_)
throw std::runtime_error("exhausted resource");
// 此处不可以无限提交
queue_.push(job);
}
cond_var_.notify_one();
return job.ret_promise->get_future();
}
~InferImpl() override {
running_ = false;
cond_var_.notify_one();
if (thread_.joinable())
thread_.join();
}
protected:
void worker(const string& filepath, promise<bool>& init_promise) {
// 初始化模型上下文
context_ = filepath;
// 假设此处为模型加载失败
if (context_.empty()) {
context_ = filepath;
init_promise.set_value(false);
return;
}
// 若干初始化上下文代码
// 上下文初始化成功
init_promise.set_value(true);
// 检查资源队列中是否有任务
std::vector<Job> jobs;
// 此处用于模拟
int batch_id = 0;
while (running_) {
{
std::unique_lock<mutex> l(lock_);
cond_var_.wait(l, [&]() {
// return true 表示 退出等待
return !queue_.empty() || !running_;
});
if (!running_)
break;
// 表示当前线程等待的最长时间
if (0 != timeout_) {
cond_var_.wait_for(l, std::chrono::milliseconds(timeout_), [&]() {
return queue_.size() >= max_batchsize_;
});
}
for (int i = 0; !queue_.empty() && i < max_batchsize_; ++i) {
jobs.emplace_back(queue_.front());
queue_.pop();
}
}
// job inference 此处初始jobs所有数据 这里模拟一下
int size = jobs.size();
for (auto& job : jobs) {
auto bbox = job.pic + "_handled";
RET_TYPE handled_result;
handled_result["bbox"] = bbox;
job.ret_promise->set_value(handled_result);
}
jobs.clear();
std::this_thread::sleep_for(std::chrono::milliseconds(100 * size));
std::printf("batch_id = %d, job size=%d \n", batch_id, size);
// 通知生产者 消费者这边已消费一次
cond_queue_overflow_.notify_one();
++batch_id;
}
// 销毁申请的资源
context_.clear();
std::cout << "context_ has cleared" << std::endl;
std::cout << "Workder done!" << std::endl;
}
};
std::shared_ptr<InferInterface> create_infer(const std::string& filepath, int max_bactch_size) {
auto infer_ptr = new InferImpl();
infer_ptr->set_max_batchsize(max_bactch_size);
if (!infer_ptr->load_model(filepath)) {
delete infer_ptr;
return nullptr;
}
return std::shared_ptr<InferInterface>(infer_ptr);
}
} // namespace Infer END
|
- 文件
main.cpp
测试功能
| #include <iostream>
#include <vector>
#include "infer.hpp"
using std::cout;
using std::endl;
using std::cerr;
using RET_FUTRUE = std::shared_future<Infer::RET_TYPE>;
int main() {
std::shared_ptr<Infer::InferInterface> infer = Infer::create_infer("/xxx/some.engine", 5);
if (infer == nullptr) {
cerr << "create infer engine error" << endl;
return -1;
}
// 允许最多等待5ms延迟收集一组数据
infer->set_timeout(5);
// 设置队列中元素的最大个数
infer->set_max_queue_num(10);
cout << "create infer engine success" << endl;
std::vector<RET_FUTRUE> shared_ptrs;
char buffer[100];
for (int i = 0; i < 24; ++i) {
sprintf(buffer, "%d.tensor", i);
// 1000 表示在队列满的情况下 生产者愿意等待1000ms,超时会抛出异常
shared_ptrs.push_back(infer->forward(buffer, 1000));
}
for (auto& shared_ptr : shared_ptrs) {
shared_ptr.get();
}
return 0;
}
|
总结
一直想着如何优雅的来去利用多线程加速模型的推理,trtpy给了我答案。回过头在看自己曾经的实现,发现一个好的接口设计比对接口的具体实现更为重要。
最后更新:
September 17, 2024
创建日期:
September 17, 2024