diff --git a/binding/core.py b/binding/core.py index ac28707..d8de580 100644 --- a/binding/core.py +++ b/binding/core.py @@ -16,8 +16,9 @@ # along with this program. If not, see import os -from rpc import OstinatoRpcChannel, OstinatoRpcController +from rpc import OstinatoRpcChannel, OstinatoRpcController, RpcError import protocols.protocol_pb2 as ost_pb +from __init__ import __version__ class DroneProxy(object): @@ -41,6 +42,12 @@ class DroneProxy(object): def connect(self): self.channel.connect(self.host, self.port) + ver = ost_pb.VersionInfo() + ver.version = __version__ + compat = self.checkVersion(ver) + if compat.result == ost_pb.VersionCompatibility.kIncompatible: + raise RpcError('incompatible version %s (%s)' % + (ver.version, compat.notes)) def disconnect(self): self.channel.disconnect() diff --git a/binding/rpc.py b/binding/rpc.py index 2078bb3..2eb721b 100644 --- a/binding/rpc.py +++ b/binding/rpc.py @@ -52,12 +52,12 @@ class OstinatoRpcChannel(RpcChannel): def __init__(self): self.log = logging.getLogger(__name__) self.log.debug('opening socket') - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) def connect(self, host, port): self.peer = '%s:%d' % (host, port) self.log.debug('connecting to %s', self.peer) try: + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((host, port)) except socket.error as e: error = 'ERROR: Unable to connect to Drone %s (%s)' % ( diff --git a/client/portgroup.cpp b/client/portgroup.cpp index e411652..8e20735 100644 --- a/client/portgroup.cpp +++ b/client/portgroup.cpp @@ -33,6 +33,7 @@ along with this program. If not, see using ::google::protobuf::NewCallback; extern QMainWindow *mainWindow; +extern char *version; quint32 PortGroup::mPortGroupAllocId = 0; @@ -47,6 +48,8 @@ PortGroup::PortGroup(QHostAddress ip, quint16 port) statsController = new PbRpcController(portIdList_, portStatsList_); isGetStatsPending_ = false; + compat = kUnknown; + reconnect = false; reconnectAfter = kMinReconnectWaitTime; reconnectTimer = new QTimer(this); @@ -112,19 +115,62 @@ void PortGroup::on_rpcChannel_stateChanged(QAbstractSocket::SocketState state) void PortGroup::on_rpcChannel_connected() { - OstProto::Void *void_ = new OstProto::Void; - OstProto::PortIdList *portIdList = new OstProto::PortIdList; + OstProto::VersionInfo *verInfo = new OstProto::VersionInfo; + OstProto::VersionCompatibility *verCompat = + new OstProto::VersionCompatibility; qDebug("connected\n"); emit portGroupDataChanged(mPortGroupId); reconnectAfter = kMinReconnectWaitTime; - qDebug("requesting portlist ..."); + qDebug("requesting version check ..."); + verInfo->set_version(version); - PbRpcController *controller = new PbRpcController(void_, portIdList); - serviceStub->getPortIdList(controller, void_, portIdList, - NewCallback(this, &PortGroup::processPortIdList, controller)); + PbRpcController *controller = new PbRpcController(verInfo, verCompat); + serviceStub->checkVersion(controller, verInfo, verCompat, + NewCallback(this, &PortGroup::processVersionCompatibility, + controller)); +} + +void PortGroup::processVersionCompatibility(PbRpcController *controller) +{ + OstProto::VersionCompatibility *verCompat + = static_cast(controller->response()); + + Q_ASSERT(verCompat != NULL); + + qDebug("got version result ..."); + + if (controller->Failed()) + { + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); + goto _error_exit; + } + + if (verCompat->result() == OstProto::VersionCompatibility::kIncompatible) { + qWarning("incompatible version %s (%s)", version, + qPrintable(QString::fromStdString(verCompat->notes()))); + compat = kIncompatible; + emit portGroupDataChanged(mPortGroupId); + goto _error_exit; + } + + compat = kCompatible; + + { + OstProto::Void *void_ = new OstProto::Void; + OstProto::PortIdList *portIdList = new OstProto::PortIdList; + + qDebug("requesting portlist ..."); + PbRpcController *controller = new PbRpcController(void_, portIdList); + serviceStub->getPortIdList(controller, void_, portIdList, + NewCallback(this, &PortGroup::processPortIdList, controller)); + } + +_error_exit: + delete controller; } void PortGroup::on_rpcChannel_disconnected() @@ -152,6 +198,9 @@ void PortGroup::on_rpcChannel_error(QAbstractSocket::SocketError socketError) qDebug("%s: error %d", __FUNCTION__, socketError); emit portGroupDataChanged(mPortGroupId); + if (socketError == QAbstractSocket::RemoteHostClosedError) + reconnect = false; + qDebug("%s: state %d", __FUNCTION__, rpcChannel->state()); if ((rpcChannel->state() == QAbstractSocket::UnconnectedState) && reconnect) { @@ -188,7 +237,8 @@ void PortGroup::processPortIdList(PbRpcController *controller) if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _error_exit; } @@ -239,7 +289,8 @@ void PortGroup::processPortConfigList(PbRpcController *controller) if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _error_exit; } @@ -369,7 +420,8 @@ void PortGroup::processModifyPortAck(PbRpcController *controller) if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _exit; } @@ -400,7 +452,8 @@ void PortGroup::processUpdatedPortConfig(PbRpcController *controller) if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _exit; } @@ -450,7 +503,8 @@ void PortGroup::processStreamIdList(int portIndex, PbRpcController *controller) if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _exit; } @@ -521,7 +575,8 @@ void PortGroup::processStreamConfigList(int portIndex, if (controller->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(controller->ErrorString())); goto _exit; } @@ -781,7 +836,8 @@ void PortGroup::processPortStatsList() if (statsController->Failed()) { - qDebug("%s: rpc failed", __FUNCTION__); + qDebug("%s: rpc failed(%s)", __FUNCTION__, + qPrintable(statsController->ErrorString())); goto _error_exit; } diff --git a/client/portgroup.h b/client/portgroup.h index b4d7580..ccb8db8 100644 --- a/client/portgroup.h +++ b/client/portgroup.h @@ -43,6 +43,7 @@ class PortGroup : public QObject { Q_OBJECT private: + enum { kIncompatible, kCompatible, kUnknown } compat; static quint32 mPortGroupAllocId; quint32 mPortGroupId; QString mUserAlias; // user defined @@ -69,9 +70,16 @@ public: quint16 port = DEFAULT_SERVER_PORT); ~PortGroup(); - void connectToHost() { reconnect = true; rpcChannel->establish(); } - void connectToHost(QHostAddress ip, quint16 port) - { reconnect = true; rpcChannel->establish(ip, port); } + void connectToHost() { + reconnect = true; + compat = kUnknown; + rpcChannel->establish(); + } + void connectToHost(QHostAddress ip, quint16 port) { + reconnect = true; + compat = kUnknown; + rpcChannel->establish(ip, port); + } void disconnectFromHost() { reconnect = false; rpcChannel->tearDown(); } int numPorts() const { return mPorts.size(); } @@ -84,9 +92,13 @@ public: { return rpcChannel->serverAddress(); } quint16 serverPort() const { return rpcChannel->serverPort(); } - QAbstractSocket::SocketState state() const - { return rpcChannel->state(); } + QAbstractSocket::SocketState state() const { + if (compat == kIncompatible) + return QAbstractSocket::SocketState(-1); + return rpcChannel->state(); + } + void processVersionCompatibility(PbRpcController *controller); void processPortIdList(PbRpcController *controller); void processPortConfigList(PbRpcController *controller); diff --git a/common/protocol.proto b/common/protocol.proto index ad9477a..2039bb4 100644 --- a/common/protocol.proto +++ b/common/protocol.proto @@ -21,6 +21,19 @@ package OstProto; option cc_generic_services = true; option py_generic_services = true; +message VersionInfo { + required string version = 1; +} + +message VersionCompatibility { + enum Compatibility { + kIncompatible = 0; + kCompatible = 1; + } + required Compatibility result = 1; + optional string notes = 2; +} + message StreamId { required uint32 id = 1; } @@ -246,5 +259,7 @@ service OstService { rpc getStats(PortIdList) returns (PortStatsList); rpc clearStats(PortIdList) returns (Ack); + + rpc checkVersion(VersionInfo) returns (VersionCompatibility); } diff --git a/rpc/pbrpccontroller.h b/rpc/pbrpccontroller.h index af9c292..eff2324 100644 --- a/rpc/pbrpccontroller.h +++ b/rpc/pbrpccontroller.h @@ -44,7 +44,12 @@ public: ::google::protobuf::Message* response() { return response_; } // Client Side Methods - void Reset() { failed = false; blob = NULL; errStr = ""; } + void Reset() { + failed = false; + disconnect = false; + blob = NULL; + errStr = ""; + } bool Failed() const { return failed; } void StartCancel() { /*! \todo (MED) */} std::string ErrorText() const { return errStr.toStdString(); } @@ -59,6 +64,12 @@ public: void NotifyOnCancel(::google::protobuf::Closure* /* callback */) { /*! \todo (MED) */ } + void TriggerDisconnect() { + disconnect = true; + } + bool Disconnect() const { + return disconnect; + } // srivatsp added QIODevice* binaryBlob() { return blob; }; @@ -66,6 +77,7 @@ public: private: bool failed; + bool disconnect; QIODevice *blob; QString errStr; ::google::protobuf::Message *request_; diff --git a/rpc/rpcconn.cpp b/rpc/rpcconn.cpp index 7273df7..3d91c75 100644 --- a/rpc/rpcconn.cpp +++ b/rpc/rpcconn.cpp @@ -50,6 +50,8 @@ RpcConnection::RpcConnection(int socketDescriptor, isPending = false; pendingMethodId = -1; // don't care as long as isPending is false + + isCompatCheckDone = false; } RpcConnection::~RpcConnection() @@ -180,7 +182,13 @@ void RpcConnection::sendRpcReply(PbRpcController *controller) response->SerializeToZeroCopyStream(outStream); outStream->Flush(); + if (pendingMethodId == 15) + isCompatCheckDone = true; + _exit: + if (controller->Disconnect()) + clientSock->disconnectFromHost(); + delete controller; isPending = false; } @@ -210,6 +218,7 @@ void RpcConnection::on_clientSock_dataAvail() const ::google::protobuf::MethodDescriptor *methodDesc; ::google::protobuf::Message *req, *resp; PbRpcController *controller; + QString error; // Do we have enough bytes for a msg header? // If yes, peek into the header and get msg length @@ -241,6 +250,23 @@ void RpcConnection::on_clientSock_dataAvail() { qDebug("server(%s): unexpected msg type %d (expected %d)", __FUNCTION__, type, PB_MSG_TYPE_REQUEST); + error = QString("unexpected msg type %1; expected %2") + .arg(type).arg(PB_MSG_TYPE_REQUEST); + goto _error_exit; + } + + // If RPC is not checkVersion, ensure compat check is already done + if (!isCompatCheckDone && method != 15) { + qDebug("server(%s): version compatibility check pending", + __FUNCTION__); + error = "version compatibility check pending"; + goto _error_exit; + } + + if (method >= service->GetDescriptor()->method_count()) + { + qDebug("server(%s): invalid method id %d", __FUNCTION__, method); + error = QString("invalid RPC method %1").arg(method); goto _error_exit; } @@ -248,13 +274,18 @@ void RpcConnection::on_clientSock_dataAvail() if (!methodDesc) { qDebug("server(%s): invalid method id %d", __FUNCTION__, method); - goto _error_exit; //! \todo Return Error to client + error = QString("invalid RPC method %1").arg(method); + goto _error_exit; } if (isPending) { qDebug("server(%s): rpc pending, try again", __FUNCTION__); - goto _error_exit; //! \todo Return Error to client + error = QString("RPC %1() is pending; only one RPC allowed at a time; " + "try again!").arg(QString::fromStdString( + service->GetDescriptor()->method( + pendingMethodId)->name())); + goto _error_exit; } pendingMethodId = method; @@ -278,7 +309,11 @@ void RpcConnection::on_clientSock_dataAvail() "missing = \n%s----->", method, req->DebugString().c_str(), req->InitializationErrorString().c_str()); - qFatal("exiting"); + error = QString("RPC %1() missing required fields in request - %2") + .arg(QString::fromStdString( + service->GetDescriptor()->method( + pendingMethodId)->name()), + QString(req->InitializationErrorString().c_str())); delete req; delete resp; @@ -306,7 +341,14 @@ void RpcConnection::on_clientSock_dataAvail() _error_exit: inStream->Skip(len); _error_exit2: - qDebug("server(%s): discarding msg from client", __FUNCTION__); + qDebug("server(%s): return error %s for msg from client", __FUNCTION__, + qPrintable(error)); + pendingMethodId = method; + isPending = true; + controller = new PbRpcController(NULL, NULL); + controller->SetFailed(error); + controller->TriggerDisconnect(); + sendRpcReply(controller); return; } diff --git a/rpc/rpcconn.h b/rpc/rpcconn.h index 7ea581e..e41d8e2 100644 --- a/rpc/rpcconn.h +++ b/rpc/rpcconn.h @@ -68,6 +68,8 @@ private: bool isPending; int pendingMethodId; + + bool isCompatCheckDone; }; #endif diff --git a/server/myservice.cpp b/server/myservice.cpp index 825ab2c..55f3bf3 100644 --- a/server/myservice.cpp +++ b/server/myservice.cpp @@ -33,6 +33,11 @@ along with this program. If not, see #include "../rpc/pbrpccontroller.h" #include "portmanager.h" +#include + + +extern char *version; + MyService::MyService() { PortManager *portManager = PortManager::instance(); @@ -531,3 +536,46 @@ void MyService::clearStats(::google::protobuf::RpcController* /*controller*/, done->Run(); } + +void MyService::checkVersion(::google::protobuf::RpcController* controller, + const ::OstProto::VersionInfo* request, + ::OstProto::VersionCompatibility* response, + ::google::protobuf::Closure* done) +{ + QString myVersion(version); + QString clientVersion; + QStringList my, client; + + qDebug("In %s", __PRETTY_FUNCTION__); + + my = myVersion.split('.'); + + Q_ASSERT(my.size() >= 2); + + clientVersion = QString::fromStdString(request->version()); + client = clientVersion.split('.'); + + qDebug("client = %s, my = %s", + qPrintable(clientVersion), qPrintable(myVersion)); + + if (client.size() < 2) + goto _invalid_version; + + // Compare only major and minor numbers + if (client[0] == my[0] && client[1] == my[1]) { + response->set_result(OstProto::VersionCompatibility::kCompatible); + } + else { + response->set_result(OstProto::VersionCompatibility::kIncompatible); + response->set_notes(QString("Drone needs client version %1.%2.x") + .arg(my[0], my[1]).toStdString()); + static_cast(controller)->TriggerDisconnect(); + } + + done->Run(); + return; + +_invalid_version: + controller->SetFailed("invalid version information"); + done->Run(); +} diff --git a/server/myservice.h b/server/myservice.h index 557a30f..15c2f5f 100644 --- a/server/myservice.h +++ b/server/myservice.h @@ -97,6 +97,10 @@ public: const ::OstProto::PortIdList* request, ::OstProto::Ack* response, ::google::protobuf::Closure* done); + virtual void checkVersion(::google::protobuf::RpcController* controller, + const ::OstProto::VersionInfo* request, + ::OstProto::VersionCompatibility* response, + ::google::protobuf::Closure* done); private: /* diff --git a/test/rpctest.py b/test/rpctest.py index cbb3f13..5b341cb 100644 --- a/test/rpctest.py +++ b/test/rpctest.py @@ -63,6 +63,7 @@ class TestSuite: host_name = '127.0.0.1' tx_port_number = -1 rx_port_number = -1 +drone_version = ['0', '0', '0'] if sys.platform == 'win32': tshark = r'C:\Program Files\Wireshark\tshark.exe' @@ -91,7 +92,123 @@ drone = DroneProxy(host_name) try: # ----------------------------------------------------------------- # - # Baseline Configuration + # TESTCASE: Verify any RPC before checkVersion() fails and the server + # closes the connection + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('anyRpcBeforeCheckVersionFails') + drone.channel.connect(drone.host, drone.port) + try: + port_id_list = drone.getPortIdList() + except RpcError as e: + if ('compatibility check pending' in str(e)): + passed = True + else: + raise + finally: + drone.channel.disconnect() + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # TESTCASE: Verify DroneProxy.connect() fails for incompatible version + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('connectFailsForIncompatibleVersion') + try: + drone.proxy_version = '0.1.1' + drone.connect() + except RpcError as e: + if ('needs client version' in str(e)): + passed = True + drone_version = str(e).split()[-1].split('.') + else: + raise + finally: + drone.proxy_version = None + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # TESTCASE: Verify checkVersion() fails for invalid client version format + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('checkVersionFailsForInvalidClientVersion') + try: + drone.proxy_version = '0-1-1' + drone.connect() + except RpcError as e: + if ('invalid version' in str(e)): + passed = True + else: + raise + finally: + drone.proxy_version = None + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # TESTCASE: Verify checkVersion() returns incompatible if the 'major' + # part of the numbering format is + # different than the server's version and the server closes + # the connection + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('checkVersionReturnsIncompatForDifferentMajorVersion') + try: + drone.proxy_version = (str(int(drone_version[0])+1) + + '.' + drone_version[1]) + drone.connect() + except RpcError as e: + #FIXME: How to check for a closed connection? + if ('needs client version' in str(e)): + passed = True + else: + raise + finally: + drone.proxy_version = None + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # TESTCASE: Verify checkVersion() returns incompatible if the 'minor' + # part of the numbering format is + # different than the server's version and the server closes + # the connection + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('checkVersionReturnsIncompatForDifferentMinorVersion') + try: + drone.proxy_version = (drone_version[0] + + '.' + str(int(drone_version[1])+1)) + drone.connect() + except RpcError as e: + #FIXME: How to check for a closed connection? + if ('needs client version' in str(e)): + passed = True + else: + raise + finally: + drone.proxy_version = None + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # TESTCASE: Verify checkVersion() returns compatible if the 'revision' + # part of the numbering format is + # different than the server's version + # ----------------------------------------------------------------- # + passed = False + suite.test_begin('checkVersionReturnsCompatForDifferentRevisionVersion') + try: + drone.proxy_version = (drone_version[0] + + '.' + drone_version[1] + + '.' + '999') + drone.connect() + passed = True + except RpcError as e: + raise + finally: + drone.proxy_version = None + suite.test_end(passed) + + # ----------------------------------------------------------------- # + # Baseline Configuration for subsequent testcases # ----------------------------------------------------------------- # # connect to drone @@ -125,7 +242,7 @@ try: log.warning('loopback port not found') sys.exit(1) - print('Using port %d as tx/rx port(s)') + print('Using port %d as tx/rx port(s)' % tx_port_number) tx_port = ost_pb.PortIdList() tx_port.port_id.add().id = tx_port_number; @@ -176,6 +293,31 @@ try: drone.clearStats(tx_port) drone.clearStats(rx_port) + # ----------------------------------------------------------------- # + # TODO: + # TESTCASE: Verify a RPC with missing required fields in request fails + # and subsequently passes when the fields are initialized + # ----------------------------------------------------------------- # +# passed = False +# suite.test_begin('rpcWithMissingRequiredFieldsFails') +# pid = ost_pb.PortId() +# try: +# sid_list = drone.getStreamIdList(pid) +# except RpcError as e: +# if ('missing required fields in request' in str(e)): +# passed = True +# else: +# raise +# +# try: +# pid.id = tx_port_number +# sid_list = drone.getStreamIdList(pid) +# except RpcError as e: +# passed = False +# raise +# finally: +# suite.test_end(passed) + # ----------------------------------------------------------------- # # TESTCASE: Verify invoking addStream() during transmit fails # TESTCASE: Verify invoking modifyStream() during transmit fails