Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions src/mediapipe_internal/graphqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,82 @@ GraphQueue::GraphQueue(const ::mediapipe::CalculatorGraphConfig& config, std::sh
SPDLOG_ERROR("Graph queue StartRun failed: {}", absStatus.ToString());
throw std::runtime_error(absStatus.ToString());
}
inferRequests.emplace_back(std::move(graphHelper));
this->inferRequests.emplace_back(std::move(graphHelper));
}
}
GraphQueue::~GraphQueue() {
for (auto& graphHelper : inferRequests) {
auto absStatus = graphHelper->graph->WaitUntilIdle();
GraphHelper::~GraphHelper() {
if (!graph) {
return;
}
auto absStatus = graph->WaitUntilIdle();
if (!absStatus.ok()) {
SPDLOG_DEBUG("GraphHelper WaitUntilIdle error: {}", absStatus.ToString());
}
absStatus = graph->CloseAllPacketSources();
if (!absStatus.ok()) {
SPDLOG_DEBUG("GraphHelper CloseAllPacketSources error: {}", absStatus.ToString());
}
absStatus = graph->WaitUntilDone();
if (!absStatus.ok()) {
SPDLOG_DEBUG("GraphHelper WaitUntilDone error: {}", absStatus.ToString());
}
graph->Cancel();
}

void GraphHelper::reinitialize(const ::mediapipe::CalculatorGraphConfig& config, const GraphSidePackets& sidePacketMaps) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unit test that will ensure failed graphs are reusable is missing

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test

SPDLOG_DEBUG("Reinitializing graph after error");
// Tear down the old graph (best-effort, errors expected since graph is in bad state)
if (this->graph) {
auto absStatus = this->graph->CloseAllPacketSources();
Comment on lines +120 to +124
if (!absStatus.ok()) {
SPDLOG_DEBUG("Graph queue WaitUntilIdle error: {}", absStatus.ToString());
SPDLOG_DEBUG("reinitialize: CloseAllPacketSources: {}", absStatus.ToString());
}
absStatus = graphHelper->graph->CloseAllPacketSources();
absStatus = this->graph->WaitUntilDone();
if (!absStatus.ok()) {
SPDLOG_DEBUG("Graph queue CloseAllPacketSources error: {}", absStatus.ToString());
SPDLOG_DEBUG("reinitialize: WaitUntilDone: {}", absStatus.ToString());
}
absStatus = graphHelper->graph->WaitUntilDone();
this->graph->Cancel();
}
// Create fresh graph
graph = std::make_unique<::mediapipe::CalculatorGraph>();
currentTimestamp = ::mediapipe::Timestamp(0);

auto absStatus = graph->Initialize(config);
if (!absStatus.ok()) {
SPDLOG_ERROR("Graph reinitialize: Initialize failed: {}", absStatus.ToString());
graph.reset();
return;
}
for (const auto& [streamName, holder] : outStreamObservers) {
absStatus = graph->ObserveOutputStream(streamName, [holder](const ::mediapipe::Packet& packet) -> absl::Status {
return holder->current->handlePacket(packet);
});
if (!absStatus.ok()) {
SPDLOG_DEBUG("Graph queue WaitUntilDone error: {}", absStatus.ToString());
SPDLOG_ERROR("Graph reinitialize: ObserveOutputStream failed: {}", absStatus.ToString());
graph.reset();
return;
}
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove graph.reset.

graphHelper->graph->Cancel();
graphHelper->graph.reset();
}
// Reset observers to null sentinel
for (const auto& [streamName, holder] : outStreamObservers) {
holder->current = std::make_shared<NullOutputStreamObserver>();
}
// Reset execution contexts
for (auto& [nodeName, ctx] : genAiExecutionContextMap) {
ctx->reset();
}
std::map<std::string, mediapipe::Packet> inputSidePackets;
buildInputSidePackets(inputSidePackets, sidePacketMaps);
inputSidePackets[LLM_EXECUTION_CONTEXT_SESSION_SIDE_PACKET_TAG] =
mediapipe::MakePacket<GenAiExecutionContextMap>(genAiExecutionContextMap)
.At(::mediapipe::Timestamp(STARTING_TIMESTAMP_VALUE));
absStatus = graph->StartRun(inputSidePackets);
if (!absStatus.ok()) {
SPDLOG_ERROR("Graph reinitialize: StartRun failed: {}", absStatus.ToString());
graph.reset();
return;
}
SPDLOG_DEBUG("Graph reinitialized successfully");
}
GraphQueue::~GraphQueue() = default;
} // namespace ovms
33 changes: 33 additions & 0 deletions src/mediapipe_internal/graphqueue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,40 @@ struct GraphHelper {
genAiExecutionContextMap(std::move(gh.genAiExecutionContextMap)),
currentTimestamp(gh.currentTimestamp) {}
GraphHelper& operator=(GraphHelper&&) = delete;
~GraphHelper();
// Tears down the current (errored) graph and rebuilds a fresh one
// with the same observers and side packets. Called when inference
// encounters a graph error to avoid returning a poisoned graph to the pool.
void reinitialize(const ::mediapipe::CalculatorGraphConfig& config, const GraphSidePackets& sidePacketMaps);
};

// RAII guard that reinitializes the graph if inference exits with an error.
// Construct before the first graph interaction (packet push). Call dismiss()
// on the success path. If not dismissed, the destructor rebuilds the graph
// so the next request from the pool gets a clean graph.
class GraphReinitGuard {
GraphHelper& helper;
const ::mediapipe::CalculatorGraphConfig& config;
const GraphSidePackets& sidePacketMaps;
bool dismissed = false;

public:
GraphReinitGuard(GraphHelper& helper,
const ::mediapipe::CalculatorGraphConfig& config,
const GraphSidePackets& sidePacketMaps) :
helper(helper),
config(config),
sidePacketMaps(sidePacketMaps) {}
void dismiss() { dismissed = true; }
~GraphReinitGuard() {
if (!dismissed) {
helper.reinitialize(config, sidePacketMaps);
}
}
Comment on lines +99 to +109
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we could only log error anyway. Will add that.

GraphReinitGuard(const GraphReinitGuard&) = delete;
GraphReinitGuard& operator=(const GraphReinitGuard&) = delete;
};

// we need to keep Graph alive during MP reload hence shared_ptr
class GraphQueue : public Queue<std::shared_ptr<GraphHelper>> {
std::shared_ptr<GraphSidePackets> sidePacketMaps;
Expand Down
12 changes: 12 additions & 0 deletions src/mediapipe_internal/mediapipegraphexecutor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ class MediapipeGraphExecutor {
guard->graphHelper->outStreamObservers.at(name)->current = std::make_shared<MyFunctor<RequestType, ResponseType>>(name, this->outputTypes.at(name), *this, *request, *response);
}

// Guard: if inference fails after this point, reinitialize the graph
// so the next request from the pool gets a clean graph (not poisoned).
GraphReinitGuard reinitGuard(*this->guard->graphHelper, this->config, this->sidePacketMaps);

size_t numberOfPacketsCreated = 0;
auto ovms_status = createAndPushPacketsImpl(
std::shared_ptr<const RequestType>(request, [](const RequestType*) {}),
Expand Down Expand Up @@ -227,6 +231,8 @@ class MediapipeGraphExecutor {
}
resetLlmExecutionContexts(this->guard->graphHelper->genAiExecutionContextMap);
MP_RETURN_ON_FAIL(status, "graph wait until idle", mediapipeAbslToOvmsStatus(status.code()));
// Success — dismiss the guard, graph is healthy
reinitGuard.dismiss();
// Increment timestamp for next request reusing this graph from the queue
this->guard->graphHelper->currentTimestamp = ::mediapipe::Timestamp(this->guard->graphHelper->currentTimestamp.Value() + 1);
SPDLOG_DEBUG("Received all output stream packets for graph: {}", this->name);
Expand Down Expand Up @@ -393,6 +399,10 @@ class MediapipeGraphExecutor {
executionContext, this->mediapipeServableMetricReporter);
}

// Guard: if streaming inference fails, reinitialize the graph
// so the next request from the pool gets a clean graph (not poisoned).
GraphReinitGuard reinitGuard(*this->guard->graphHelper, this->config, this->sidePacketMaps);

size_t numberOfPacketsCreated = 0;
{
OVMS_PROFILE_SCOPE("Mediapipe graph deserializing first request");
Expand Down Expand Up @@ -450,6 +460,8 @@ class MediapipeGraphExecutor {
}
resetLlmExecutionContexts(this->guard->graphHelper->genAiExecutionContextMap);
MP_RETURN_ON_FAIL(status, "graph wait until idle", mediapipeAbslToOvmsStatus(status.code()));
// Success — dismiss the guard, graph is healthy
reinitGuard.dismiss();
// Increment timestamp for next request reusing this graph from the queue
this->guard->graphHelper->currentTimestamp = ::mediapipe::Timestamp(this->guard->graphHelper->currentTimestamp.Value() + 1);
SPDLOG_DEBUG("Graph {}: Done streaming execution (queue path)", this->name);
Expand Down