From d999ffc12211f5d5c6cd2196dd525e69e2fc0f58 Mon Sep 17 00:00:00 2001 From: Vernon Mauery Date: Thu, 25 Oct 2018 09:16:05 -0700 Subject: netipmid: use shared_ptr on messages instead of unique_ptr+references Messages were being created and held by unique_ptr objects and then shared via reference. This is dangerous and sidesteps the whole point of a unique_ptr, which is to enforce single ownership. This replaces the usage with a shared_ptr, which denotes shared ownership. Change-Id: I19ed2693f5a0f5ce47d720ed255fa05bdf3844f8 Signed-off-by: Vernon Mauery --- message_handler.cpp | 88 ++++++++++++++++++++++++++--------------------------- message_handler.hpp | 15 ++++----- message_parsers.cpp | 71 +++++++++++++++++++++--------------------- message_parsers.hpp | 27 +++++++++------- sd_event_loop.cpp | 8 ++--- 5 files changed, 107 insertions(+), 102 deletions(-) diff --git a/message_handler.cpp b/message_handler.cpp index c5d5d4e..b9cc6ab 100644 --- a/message_handler.cpp +++ b/message_handler.cpp @@ -16,7 +16,7 @@ namespace message { -std::unique_ptr Handler::receive() +std::shared_ptr Handler::receive() { std::vector packet; auto readStatus = 0; @@ -32,7 +32,7 @@ std::unique_ptr Handler::receive() } // Unflatten the packet - std::unique_ptr message; + std::shared_ptr message; std::tie(message, sessionHeader) = parser::unflatten(packet); auto session = std::get(singletonPool) @@ -46,18 +46,17 @@ std::unique_ptr Handler::receive() } template <> -std::unique_ptr - Handler::createResponse(std::vector& output, - Message& inMessage) +std::shared_ptr Handler::createResponse( + std::vector& output, std::shared_ptr inMessage) { - auto outMessage = std::make_unique(); + auto outMessage = std::make_shared(); outMessage->payloadType = PayloadType::IPMI; outMessage->payload.resize(sizeof(LAN::header::Response) + output.size() + sizeof(LAN::trailer::Response)); auto reqHeader = - reinterpret_cast(inMessage.payload.data()); + reinterpret_cast(inMessage->payload.data()); auto respHeader = reinterpret_cast(outMessage->payload.data()); @@ -84,22 +83,23 @@ std::unique_ptr return outMessage; } -std::unique_ptr Handler::executeCommand(Message& inMessage) +std::shared_ptr + Handler::executeCommand(std::shared_ptr inMessage) { // Get the CommandID to map into the command table auto command = getCommand(inMessage); std::vector output{}; - if (inMessage.payloadType == PayloadType::IPMI) + if (inMessage->payloadType == PayloadType::IPMI) { - if (inMessage.payload.size() < + if (inMessage->payload.size() < (sizeof(LAN::header::Request) + sizeof(LAN::trailer::Request))) { return nullptr; } - auto start = inMessage.payload.begin() + sizeof(LAN::header::Request); - auto end = inMessage.payload.end() - sizeof(LAN::trailer::Request); + auto start = inMessage->payload.begin() + sizeof(LAN::header::Request); + auto end = inMessage->payload.end() - sizeof(LAN::trailer::Request); std::vector inPayload(start, end); output = std::get(singletonPool) @@ -108,12 +108,12 @@ std::unique_ptr Handler::executeCommand(Message& inMessage) else { output = std::get(singletonPool) - .executeCommand(command, inMessage.payload, *this); + .executeCommand(command, inMessage->payload, *this); } - std::unique_ptr outMessage = nullptr; + std::shared_ptr outMessage = nullptr; - switch (inMessage.payloadType) + switch (inMessage->payloadType) { case PayloadType::IPMI: outMessage = createResponse(output, inMessage); @@ -135,34 +135,34 @@ std::unique_ptr Handler::executeCommand(Message& inMessage) break; } - outMessage->isPacketEncrypted = inMessage.isPacketEncrypted; - outMessage->isPacketAuthenticated = inMessage.isPacketAuthenticated; - outMessage->rcSessionID = inMessage.rcSessionID; - outMessage->bmcSessionID = inMessage.bmcSessionID; + outMessage->isPacketEncrypted = inMessage->isPacketEncrypted; + outMessage->isPacketAuthenticated = inMessage->isPacketAuthenticated; + outMessage->rcSessionID = inMessage->rcSessionID; + outMessage->bmcSessionID = inMessage->bmcSessionID; return outMessage; } -uint32_t Handler::getCommand(Message& message) +uint32_t Handler::getCommand(std::shared_ptr message) { uint32_t command = 0; - command |= (static_cast(message.payloadType) << 16); - if (message.payloadType == PayloadType::IPMI) + command |= (static_cast(message->payloadType) << 16); + if (message->payloadType == PayloadType::IPMI) { command |= - ((reinterpret_cast(message.payload.data())) + ((reinterpret_cast(message->payload.data())) ->netfn) << 8; command |= - (reinterpret_cast(message.payload.data())) + (reinterpret_cast(message->payload.data())) ->cmd; } return command; } -void Handler::send(Message& outMessage) +void Handler::send(std::shared_ptr outMessage) { auto session = std::get(singletonPool).getSession(sessionID); @@ -188,17 +188,16 @@ void Handler::setChannelInSession() const void Handler::sendSOLPayload(const std::vector& input) { - Message outMessage; - auto session = std::get(singletonPool).getSession(sessionID); - outMessage.payloadType = PayloadType::SOL; - outMessage.payload = input; - outMessage.isPacketEncrypted = session->isCryptAlgoEnabled(); - outMessage.isPacketAuthenticated = session->isIntegrityAlgoEnabled(); - outMessage.rcSessionID = session->getRCSessionID(); - outMessage.bmcSessionID = sessionID; + auto outMessage = std::make_shared(); + outMessage->payloadType = PayloadType::SOL; + outMessage->payload = input; + outMessage->isPacketEncrypted = session->isCryptAlgoEnabled(); + outMessage->isPacketAuthenticated = session->isIntegrityAlgoEnabled(); + outMessage->rcSessionID = session->getRCSessionID(); + outMessage->bmcSessionID = sessionID; send(outMessage); } @@ -206,22 +205,21 @@ void Handler::sendSOLPayload(const std::vector& input) void Handler::sendUnsolicitedIPMIPayload(uint8_t netfn, uint8_t cmd, const std::vector& output) { - Message outMessage; - auto session = std::get(singletonPool).getSession(sessionID); - outMessage.payloadType = PayloadType::IPMI; - outMessage.isPacketEncrypted = session->isCryptAlgoEnabled(); - outMessage.isPacketAuthenticated = session->isIntegrityAlgoEnabled(); - outMessage.rcSessionID = session->getRCSessionID(); - outMessage.bmcSessionID = sessionID; + auto outMessage = std::make_shared(); + outMessage->payloadType = PayloadType::IPMI; + outMessage->isPacketEncrypted = session->isCryptAlgoEnabled(); + outMessage->isPacketAuthenticated = session->isIntegrityAlgoEnabled(); + outMessage->rcSessionID = session->getRCSessionID(); + outMessage->bmcSessionID = sessionID; - outMessage.payload.resize(sizeof(LAN::header::Request) + output.size() + - sizeof(LAN::trailer::Request)); + outMessage->payload.resize(sizeof(LAN::header::Request) + output.size() + + sizeof(LAN::trailer::Request)); auto respHeader = - reinterpret_cast(outMessage.payload.data()); + reinterpret_cast(outMessage->payload.data()); // Add IPMI LAN Message Request Header respHeader->rsaddr = LAN::requesterBMCAddress; @@ -235,12 +233,12 @@ void Handler::sendUnsolicitedIPMIPayload(uint8_t netfn, uint8_t cmd, // Copy the output by the execution of the command std::copy(output.begin(), output.end(), - outMessage.payload.begin() + assembledSize); + outMessage->payload.begin() + assembledSize); assembledSize += output.size(); // Add the IPMI LAN Message Trailer auto trailer = reinterpret_cast( - outMessage.payload.data() + assembledSize); + outMessage->payload.data() + assembledSize); // Calculate the checksum for the field rqaddr in the header to the // command data, 3 corresponds to size of the fields before rqaddr( rsaddr, diff --git a/message_handler.hpp b/message_handler.hpp index 063f8d2..3c99660 100644 --- a/message_handler.hpp +++ b/message_handler.hpp @@ -6,6 +6,7 @@ #include "sol/console_buffer.hpp" #include +#include #include namespace message @@ -38,7 +39,7 @@ class Handler * @return IPMI Message on success and nullptr on failure * */ - std::unique_ptr receive(); + std::shared_ptr receive(); /** * @brief Process the incoming IPMI message @@ -51,7 +52,7 @@ class Handler * * @return Outgoing message on success and nullptr on failure */ - std::unique_ptr executeCommand(Message& inMessage); + std::shared_ptr executeCommand(std::shared_ptr inMessage); /** @brief Send the outgoing message * @@ -60,7 +61,7 @@ class Handler * * @param[in] outMessage - Outgoing Message */ - void send(Message& outMessage); + void send(std::shared_ptr outMessage); /** @brief Set socket channel in session object */ void setChannelInSession() const; @@ -109,10 +110,10 @@ class Handler * @return Outgoing message on success and nullptr on failure */ template - std::unique_ptr createResponse(std::vector& output, - Message& inMessage) + std::shared_ptr createResponse(std::vector& output, + std::shared_ptr inMessage) { - auto outMessage = std::make_unique(); + auto outMessage = std::make_shared(); outMessage->payloadType = T; outMessage->payload = output; return outMessage; @@ -125,7 +126,7 @@ class Handler * * @return Command ID in the incoming message */ - uint32_t getCommand(Message& message); + uint32_t getCommand(std::shared_ptr message); /** * @brief Calculate 8 bit 2's complement checksum diff --git a/message_parsers.cpp b/message_parsers.cpp index 7497747..ddb9d01 100644 --- a/message_parsers.cpp +++ b/message_parsers.cpp @@ -14,7 +14,7 @@ namespace message namespace parser { -std::tuple, SessionHeader> +std::tuple, SessionHeader> unflatten(std::vector& inPacket) { // Check if the packet has atleast the size of the RMCP Header @@ -54,8 +54,8 @@ std::tuple, SessionHeader> } } -std::vector flatten(Message& outMessage, SessionHeader authType, - session::Session& session) +std::vector flatten(std::shared_ptr outMessage, + SessionHeader authType, session::Session& session) { // Call the flatten routine based on the header type switch (authType) @@ -80,7 +80,7 @@ std::vector flatten(Message& outMessage, SessionHeader authType, namespace ipmi15parser { -std::unique_ptr unflatten(std::vector& inPacket) +std::shared_ptr unflatten(std::vector& inPacket) { // Check if the packet has atleast the Session Header if (inPacket.size() < sizeof(SessionHeader_t)) @@ -88,7 +88,7 @@ std::unique_ptr unflatten(std::vector& inPacket) throw std::runtime_error("IPMI1.5 Session Header Missing"); } - auto message = std::make_unique(); + auto message = std::make_shared(); auto header = reinterpret_cast(inPacket.data()); @@ -107,7 +107,8 @@ std::unique_ptr unflatten(std::vector& inPacket) return message; } -std::vector flatten(Message& outMessage, session::Session& session) +std::vector flatten(std::shared_ptr outMessage, + session::Session& session) { std::vector packet(sizeof(SessionHeader_t)); @@ -120,13 +121,13 @@ std::vector flatten(Message& outMessage, session::Session& session) header->base.format.formatType = static_cast(parser::SessionHeader::IPMI15); header->sessSeqNum = 0; - header->sessId = endian::to_ipmi(outMessage.rcSessionID); + header->sessId = endian::to_ipmi(outMessage->rcSessionID); - header->payloadLength = static_cast(outMessage.payload.size()); + header->payloadLength = static_cast(outMessage->payload.size()); // Insert the Payload into the Packet - packet.insert(packet.end(), outMessage.payload.begin(), - outMessage.payload.end()); + packet.insert(packet.end(), outMessage->payload.begin(), + outMessage->payload.end()); // Insert the Session Trailer packet.resize(packet.size() + sizeof(SessionTrailer_t)); @@ -142,7 +143,7 @@ std::vector flatten(Message& outMessage, session::Session& session) namespace ipmi20parser { -std::unique_ptr unflatten(std::vector& inPacket) +std::shared_ptr unflatten(std::vector& inPacket) { // Check if the packet has atleast the Session Header if (inPacket.size() < sizeof(SessionHeader_t)) @@ -150,7 +151,7 @@ std::unique_ptr unflatten(std::vector& inPacket) throw std::runtime_error("IPMI2.0 Session Header Missing"); } - auto message = std::make_unique(); + auto message = std::make_shared(); auto header = reinterpret_cast(inPacket.data()); @@ -166,8 +167,7 @@ std::unique_ptr unflatten(std::vector& inPacket) if (message->isPacketAuthenticated) { - if (!(internal::verifyPacketIntegrity(inPacket, *(message.get()), - payloadLen))) + if (!(internal::verifyPacketIntegrity(inPacket, message, payloadLen))) { throw std::runtime_error("Packet Integrity check failed"); } @@ -178,7 +178,7 @@ std::unique_ptr unflatten(std::vector& inPacket) { // Assign the decrypted payload to the IPMI Message message->payload = - internal::decryptPayload(inPacket, *(message.get()), payloadLen); + internal::decryptPayload(inPacket, message, payloadLen); } else { @@ -190,7 +190,8 @@ std::unique_ptr unflatten(std::vector& inPacket) return message; } -std::vector flatten(Message& outMessage, session::Session& session) +std::vector flatten(std::shared_ptr outMessage, + session::Session& session) { std::vector packet(sizeof(SessionHeader_t)); @@ -201,8 +202,8 @@ std::vector flatten(Message& outMessage, session::Session& session) header->base.classOfMsg = parser::RMCP_MESSAGE_CLASS_IPMI; header->base.format.formatType = static_cast(parser::SessionHeader::IPMI20); - header->payloadType = static_cast(outMessage.payloadType); - header->sessId = endian::to_ipmi(outMessage.rcSessionID); + header->payloadType = static_cast(outMessage->payloadType); + header->sessId = endian::to_ipmi(outMessage->rcSessionID); // Add session sequence number internal::addSequenceNumber(packet, session); @@ -210,7 +211,7 @@ std::vector flatten(Message& outMessage, session::Session& session) size_t payloadLen = 0; // Encrypt the payload if needed - if (outMessage.isPacketEncrypted) + if (outMessage->isPacketEncrypted) { header->payloadType |= PAYLOAD_ENCRYPT_MASK; auto cipherPayload = internal::encryptPayload(outMessage); @@ -223,15 +224,15 @@ std::vector flatten(Message& outMessage, session::Session& session) else { header->payloadLength = - endian::to_ipmi(outMessage.payload.size()); - payloadLen = outMessage.payload.size(); + endian::to_ipmi(outMessage->payload.size()); + payloadLen = outMessage->payload.size(); // Insert the Payload into the Packet - packet.insert(packet.end(), outMessage.payload.begin(), - outMessage.payload.end()); + packet.insert(packet.end(), outMessage->payload.begin(), + outMessage->payload.end()); } - if (outMessage.isPacketAuthenticated) + if (outMessage->isPacketAuthenticated) { internal::addIntegrityData(packet, outMessage, payloadLen); } @@ -258,7 +259,8 @@ void addSequenceNumber(std::vector& packet, session::Session& session) } bool verifyPacketIntegrity(const std::vector& packet, - const Message& message, size_t payloadLen) + const std::shared_ptr message, + size_t payloadLen) { /* * Padding bytes are added to cause the number of bytes in the data range @@ -281,7 +283,7 @@ bool verifyPacketIntegrity(const std::vector& packet, } auto session = std::get(singletonPool) - .getSession(message.bmcSessionID); + .getSession(message->bmcSessionID); auto integrityAlgo = session->getIntegrityAlgo(); @@ -305,8 +307,8 @@ bool verifyPacketIntegrity(const std::vector& packet, return integrityAlgo->verifyIntegrityData(packet, length, integrityIter); } -void addIntegrityData(std::vector& packet, const Message& message, - size_t payloadLen) +void addIntegrityData(std::vector& packet, + const std::shared_ptr message, size_t payloadLen) { // The following logic calculates the number of padding bytes to be added to // IPMI packet. If needed each integrity Pad byte is set to FFh. @@ -322,7 +324,7 @@ void addIntegrityData(std::vector& packet, const Message& message, trailer->nextHeader = parser::RMCP_MESSAGE_CLASS_IPMI; auto session = std::get(singletonPool) - .getSession(message.bmcSessionID); + .getSession(message->bmcSessionID); auto integrityData = session->getIntegrityAlgo()->generateIntegrityData(packet); @@ -331,21 +333,22 @@ void addIntegrityData(std::vector& packet, const Message& message, } std::vector decryptPayload(const std::vector& packet, - const Message& message, size_t payloadLen) + const std::shared_ptr message, + size_t payloadLen) { auto session = std::get(singletonPool) - .getSession(message.bmcSessionID); + .getSession(message->bmcSessionID); return session->getCryptAlgo()->decryptPayload( packet, sizeof(SessionHeader_t), payloadLen); } -std::vector encryptPayload(Message& message) +std::vector encryptPayload(std::shared_ptr message) { auto session = std::get(singletonPool) - .getSession(message.bmcSessionID); + .getSession(message->bmcSessionID); - return session->getCryptAlgo()->encryptPayload(message.payload); + return session->getCryptAlgo()->encryptPayload(message->payload); } } // namespace internal diff --git a/message_parsers.hpp b/message_parsers.hpp index 0dae43b..b38da40 100644 --- a/message_parsers.hpp +++ b/message_parsers.hpp @@ -59,7 +59,7 @@ struct BasicHeader_t * header type. In case of failure nullptr and session header type * would be invalid. */ -std::tuple, SessionHeader> +std::tuple, SessionHeader> unflatten(std::vector& inPacket); /** @@ -72,8 +72,8 @@ std::tuple, SessionHeader> * * @return IPMI packet on success */ -std::vector flatten(Message& outMessage, SessionHeader authType, - session::Session& session); +std::vector flatten(std::shared_ptr outMessage, + SessionHeader authType, session::Session& session); } // namespace parser @@ -101,7 +101,7 @@ struct SessionTrailer_t * * @return IPMI message in the packet on success */ -std::unique_ptr unflatten(std::vector& inPacket); +std::shared_ptr unflatten(std::vector& inPacket); /** * @brief Flatten an IPMI message and generate the IPMI packet with the @@ -111,7 +111,8 @@ std::unique_ptr unflatten(std::vector& inPacket); * * @return IPMI packet on success */ -std::vector flatten(Message& outMessage, session::Session& session); +std::vector flatten(std::shared_ptr outMessage, + session::Session& session); } // namespace ipmi15parser @@ -147,7 +148,7 @@ struct SessionTrailer_t * * @return IPMI message in the packet on success */ -std::unique_ptr unflatten(std::vector& inPacket); +std::shared_ptr unflatten(std::vector& inPacket); /** * @brief Flatten an IPMI message and generate the IPMI packet with the @@ -157,7 +158,8 @@ std::unique_ptr unflatten(std::vector& inPacket); * * @return IPMI packet on success */ -std::vector flatten(Message& outMessage, session::Session& session); +std::vector flatten(std::shared_ptr outMessage, + session::Session& session); namespace internal { @@ -180,7 +182,8 @@ void addSequenceNumber(std::vector& packet, session::Session& session); * */ bool verifyPacketIntegrity(const std::vector& packet, - const Message& message, size_t payloadLen); + const std::shared_ptr message, + size_t payloadLen); /** * @brief Add Integrity data to the outgoing IPMI packet @@ -189,7 +192,8 @@ bool verifyPacketIntegrity(const std::vector& packet, * @param[in] message - IPMI Message populated for the outgoing packet * @param[in] payloadLen - Length of the IPMI payload */ -void addIntegrityData(std::vector& packet, const Message& message, +void addIntegrityData(std::vector& packet, + const std::shared_ptr message, size_t payloadLen); /** @@ -202,7 +206,8 @@ void addIntegrityData(std::vector& packet, const Message& message, * @return on successful completion, return the plain text payload */ std::vector decryptPayload(const std::vector& packet, - const Message& message, size_t payloadLen); + const std::shared_ptr message, + size_t payloadLen); /** * @brief Encrypt the plain text payload for the outgoing IPMI packet @@ -211,7 +216,7 @@ std::vector decryptPayload(const std::vector& packet, * * @return on successful completion, return the encrypted payload */ -std::vector encryptPayload(Message& message); +std::vector encryptPayload(std::shared_ptr message); } // namespace internal diff --git a/sd_event_loop.cpp b/sd_event_loop.cpp index a5fd41d..aa5224a 100644 --- a/sd_event_loop.cpp +++ b/sd_event_loop.cpp @@ -29,24 +29,22 @@ static int udp623Handler(sd_event_source* es, int fd, uint32_t revents, // Initialize the Message Handler with the socket channel message::Handler msgHandler(channelPtr); - std::unique_ptr inMessage; - // Read the incoming IPMI packet - inMessage = msgHandler.receive(); + std::shared_ptr inMessage(msgHandler.receive()); if (inMessage == nullptr) { return 0; } // Execute the Command - auto outMessage = msgHandler.executeCommand(*(inMessage.get())); + auto outMessage = msgHandler.executeCommand(inMessage); if (outMessage == nullptr) { return 0; } // Send the response IPMI Message - msgHandler.send(*(outMessage.get())); + msgHandler.send(outMessage); } catch (std::exception& e) { -- cgit v1.2.1