// rtc_plugins/rtc_plugins.cpp
//
// 111 lines
// 3.5 KiB
// C++
// Raw Blame History
//
// This file contains ambiguous Unicode characters
//
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#define IMPLEMENT_NUMPY_API
#include "util/RTCContext.h"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h> // pybind11 的 NumPy 支持
#include <pybind11/detail/common.h>
namespace py = pybind11;
#include "util/RTCContext.h"
// Conversion accessor: exposes RTC_PLUGINS_ARRAY_API as a void** to callers.
// NOTE(review): presumably RTC_PLUGINS_ARRAY_API is the NumPy C-API table
// symbol declared via util/RTCContext.h — confirm against that header.
void** get_numpy_api() {
    // The C-style cast is kept on purpose: the declared type of
    // RTC_PLUGINS_ARRAY_API is not visible from this file.
    void** apiTable = (void**)RTC_PLUGINS_ARRAY_API;
    return apiTable;
}
// Initializes the NumPy C API for use alongside pybind11.
//
// Throws py::error_already_set when import_array() reports a negative result,
// so the Python error that is already set propagates to the importer.
// On success, prints the address held in PyArray_API as a diagnostic.
//
// NOTE(review): stock NumPy defines import_array() as a statement macro that
// cannot be used inside an expression; this code relies on the project's
// headers providing an int-returning form — confirm against util/RTCContext.h.
void init_numpy() {
    if (import_array() < 0) {
        // pybind11 has no `bind_already_set`; `error_already_set` is the
        // exception type that carries the currently-set Python error.
        throw py::error_already_set();
    }
    std::cout << "NumPy API addr: " << PyArray_API << std::endl;
}
// Stores the Python callback on the RTC context and then initializes it.
//
// @param selfUserId       forwarded to RTCContext::init
// @param selfDisplayName  forwarded to RTCContext::init
// @param selfRoomId       forwarded to RTCContext::init
// @param callback         Python object handed to RTCContext::setPyCallback
// @return 0 when RTCContext::init reports success, -1 otherwise
int init(const char* selfUserId, const char* selfDisplayName, const char* selfRoomId, py::object callback) {
    // Diagnostic only: report whether the NumPy API pointer is set on entry.
    if (PyArray_API) {
        std::cout << "PyArray_API is not null in outer init" << std::endl;
    } else {
        std::cout << "PyArray_API is null in outer init" << std::endl;
    }

    // The callback is registered before the context is initialized.
    RTCContext::instance().setPyCallback(callback);
    const bool ok = RTCContext::instance().init(selfUserId, selfDisplayName, selfRoomId);
    return ok ? 0 : -1;
}
// Sets up the receive side by delegating to RTCContext::initRecv.
//
// @param destRoomId        forwarded to RTCContext::initRecv
// @param srcUserId         forwarded to RTCContext::initRecv
// @param destChannelIndex  forwarded to RTCContext::initRecv
// @return 0 when RTCContext::initRecv reports success, -1 otherwise
int initRecv(const char* destRoomId, const char* srcUserId, const int16_t destChannelIndex) {
    // Diagnostic only: report whether the NumPy API pointer is set on entry.
    if (PyArray_API) {
        std::cout << "PyArray_API is not null in outer initRecv" << std::endl;
    } else {
        std::cout << "PyArray_API is null in outer initRecv" << std::endl;
    }

    const bool ok = RTCContext::instance().initRecv(destRoomId, srcUserId, destChannelIndex);
    return ok ? 0 : -1;
}
// Sets up the send side by delegating to RTCContext::initSend.
//
// @param srcRoomId         forwarded to RTCContext::initSend
// @param destRoomId        forwarded to RTCContext::initSend
// @param destChannelIndex  forwarded to RTCContext::initSend
// @param channelNum        forwarded to RTCContext::initSend
// @return 0 when RTCContext::initSend reports success, -1 otherwise
int initSend(const char* srcRoomId, const char* destRoomId, const int16_t destChannelIndex, const int16_t channelNum) {
    const bool ok = RTCContext::instance().initSend(srcRoomId, destRoomId, destChannelIndex, channelNum);
    return ok ? 0 : -1;
}
// NumPy data exchange (key modification).
//
// Builds a py::array_t<int16_t> from whatever RTCContext::getNumpyData()
// returns and hands it back to Python.
py::array_t<int16_t> getNumpyData() {
    auto&& ctx = RTCContext::instance();
    // Original author's note: assumes the context already returns an existing
    // NumPy array.
    return py::array_t<int16_t>(ctx.getNumpyData());
}
// Sends a block of int16 audio samples through the RTC context.
//
// @param destChannelIndex  forwarded to RTCContext::sendCustomAudioData
// @param inputArray        NumPy array holding the samples
// @param sampleRate        forwarded to RTCContext::sendCustomAudioData
// @param channelNum        forwarded to RTCContext::sendCustomAudioData
// @param dataLen           must equal the total number of elements in inputArray
// @return whatever RTCContext::sendCustomAudioData returns
// @throws py::value_error when the array's element count differs from dataLen
//
// NOTE(review): the raw buffer pointer is passed through as-is; the array type
// used here does not by itself force a C-contiguous layout — confirm callers
// always pass contiguous arrays.
int sendCustomAudioData(int16_t destChannelIndex, py::array_t<int16_t> inputArray,
                        int32_t sampleRate, uint64_t channelNum, uint64_t dataLen) {
    // request() goes through the Python buffer protocol, so it has to run
    // while the GIL is still held.
    auto buf = inputArray.request();
    if (buf.size < 0 || static_cast<uint64_t>(buf.size) != dataLen) {
        throw py::value_error("Array length does not match dataLen");
    }

    // Release the GIL only around the native send. `release` is declared after
    // `buf`, so it is destroyed first and the GIL is held again by the time
    // `buf` releases its Python buffer.
    py::gil_scoped_release release;
    return RTCContext::instance().sendCustomAudioData(
        destChannelIndex, buf.ptr, sampleRate, channelNum, dataLen
    );
}
// Returns the py::list produced by RTCContext::getListData().
py::list getListData() {
    auto&& ctx = RTCContext::instance();
    return ctx.getListData();
}
// Returns the value reported by RTCContext::getSize().
int getSize() {
    auto&& ctx = RTCContext::instance();
    return ctx.getSize();
}
// Returns the RetAudioFrame produced by RTCContext::getData().
RetAudioFrame getData() {
    auto&& ctx = RTCContext::instance();
    return ctx.getData();
}
// Returns the value reported by RTCContext::getDataCount().
int16_t getDataCount() {
    auto&& ctx = RTCContext::instance();
    return ctx.getDataCount();
}
// Python module definition for `rtc_plugins`.
// Runs init_numpy() first, then registers the RetAudioFrame class and the
// free functions defined above under their Python-visible names.
PYBIND11_MODULE(rtc_plugins, m) {
    init_numpy();

    // Optional: expose the RetAudioFrame class (requires this extra binding).
    py::class_<RetAudioFrame> frameBinding(m, "RetAudioFrame");
    frameBinding
        .def_readwrite("data", &RetAudioFrame::data)
        .def_readwrite("dataCount", &RetAudioFrame::dataCount)
        .def_readwrite("sampleRate", &RetAudioFrame::sampleRate)
        .def_readwrite("numChannels", &RetAudioFrame::numChannels)
        .def_readwrite("channelIndex", &RetAudioFrame::channelIndex);

    // Free functions, registered in the same order as before.
    m.def("init", &init)
        .def("initRecv", &initRecv)
        .def("initSend", &initSend)
        .def("sendCustomAudioData", &sendCustomAudioData)
        .def("getSize", &getSize)
        .def("getData", &getData)
        .def("getNumpyData", &getNumpyData)
        .def("getListData", &getListData)
        .def("getDataCount", &getDataCount);
}