C++ 20 coroutine

2025/03/06 coroutine 共 4038 字,约 12 分钟
程序员

协程消息驱动小例子

1. 需求概述

本项目基于 C++20 协程 (std::coroutine) 实现一个 轻量级的消息驱动异步调度协程小例子,主要功能如下:

  • 协程等待消息 (co_await 某个 key),当该消息到达时,协程会自动恢复。
  • 支持乱序消息到达,协程可以按依赖关系正确处理恢复顺序。
  • 任务队列调度机制,避免协程直接 resume() 可能带来的栈过深问题。
  • 无锁单线程模型每个线程独立管理一个 CoroutineContext,不需要加锁,提高效率。
  • 异常捕获机制,协程内部异常会被 std::exception_ptr 捕获,可在外部重新抛出处理。
  • 支持清理历史消息,避免消息数据占用过多内存。

2. 目录结构

/project
│── src
│   ├── coroutine_context.h    // 协程上下文(消息管理、调度)
│   ├── task.h                 // 协程封装 Task(管理 resume/destroy)
│   ├── main.cpp               // 入口文件,示例协程及消息流
│
└── README.md                  // 本文档

3. 编译 & 运行

(1) 编译

需要支持 C++20 的编译器,例如 GCC 11+Clang 14+

g++ -std=c++20 -pthread -o main src/main.cpp

(2) 运行

./main

程序运行后,你可以看到 ABC 协程按正确顺序恢复,即使消息是乱序到达的。


4. 代码介绍

(1) coroutine_context.h:协程上下文

#pragma once

#include <coroutine>
#include <unordered_map>
#include <string>
#include <queue>
#include <functional>
#include <vector>

// 协程上下文:管理消息、等待者、任务队列
class CoroutineContext {
public:
    struct Awaiter {
        CoroutineContext &ctx_;
        std::string msg_key_;

        Awaiter(CoroutineContext &ctx, std::string key) : ctx_(ctx), msg_key_(std::move(key)) {}

        bool await_ready() noexcept { return ctx_.messages_.find(msg_key_) != ctx_.messages_.end(); }
        void await_suspend(std::coroutine_handle<> h) noexcept { ctx_.waiters_[msg_key_].push_back(h); }
        std::string await_resume() noexcept { return ctx_.messages_[msg_key_]; }
    };

    Awaiter wait_for(const std::string &key) { return Awaiter(*this, key); }

    void message_arrive(const std::string &key, const std::string &value) {
        messages_[key] = value;
        if (waiters_.find(key) != waiters_.end()) {
            for (auto h : waiters_[key]) tasks_.push([h]() { h.resume(); });
            waiters_[key].clear();
        }
        run_tasks();
    }

    void run_tasks() {
        while (!tasks_.empty()) {
            auto task = tasks_.front();
            tasks_.pop();
            task();
        }
    }

    void purge_message(const std::string &key) { messages_.erase(key); }

private:
    std::unordered_map<std::string, std::string> messages_;
    std::unordered_map<std::string, std::vector<std::coroutine_handle<>>> waiters_;
    std::queue<std::function<void()>> tasks_;
};

(2) task.h:封装协程 Task

#pragma once

#include <coroutine>
#include <exception>

struct Task {
    struct promise_type {
        std::exception_ptr exception_;

        Task get_return_object() noexcept { return Task{ std::coroutine_handle<promise_type>::from_promise(*this) }; }
        std::suspend_always initial_suspend() noexcept { return {}; }
        std::suspend_always final_suspend() noexcept { return {}; }
        void unhandled_exception() noexcept { exception_ = std::current_exception(); }
        void return_void() noexcept {}
    };

    explicit Task(std::coroutine_handle<promise_type> h) : coro_(h) {}

    void resume() { if (coro_ && !coro_.done()) coro_.resume(); }

    void rethrow_exception_if_any() {
        if (coro_.done() && coro_.promise().exception_) std::rethrow_exception(coro_.promise().exception_);
    }

    ~Task() { if (coro_) coro_.destroy(); }

private:
    std::coroutine_handle<promise_type> coro_;
};

(3) main.cpp:示例协程 & 运行逻辑

#include "coroutine_context.h"
#include "task.h"
#include <iostream>
#include <thread>
#include <vector>
#include <chrono>

Task process_A(CoroutineContext &ctx) {
    auto msgA = co_await ctx.wait_for("A");
    std::cout << "[A] Received: " << msgA << "
";
}

Task process_B(CoroutineContext &ctx) {
    co_await ctx.wait_for("A");
    auto msgB = co_await ctx.wait_for("B");
    std::cout << "[B] Received: " << msgB << "
";
}

Task process_C(CoroutineContext &ctx) {
    co_await ctx.wait_for("B");
    auto msgC = co_await ctx.wait_for("C");
    std::cout << "[C] Received: " << msgC << "
";
}

void data_stream_thread(int stream_id, std::vector<std::string> msg_sequence) {
    CoroutineContext ctx;

    auto taskA = process_A(ctx);
    auto taskB = process_B(ctx);
    auto taskC = process_C(ctx);

    taskA.resume();
    taskB.resume();
    taskC.resume();

    for (const auto &msg : msg_sequence) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        ctx.message_arrive(msg, "Data from " + msg + " (stream " + std::to_string(stream_id) + ")");
    }

    taskA.rethrow_exception_if_any();
    taskB.rethrow_exception_if_any();
    taskC.rethrow_exception_if_any();
}

int main() {
    std::vector<std::thread> threads;
    std::vector<std::vector<std::string>> test_scenarios = {
        {"B", "A", "C"}
    };

    for (int i = 0; i < test_scenarios.size(); ++i) {
        threads.emplace_back(data_stream_thread, i, test_scenarios[i]);
    }

    for (auto &t : threads) {
        t.join();
    }

    return 0;
}

文档信息

Search

    Table of Contents