Skip to content

Commit

Permalink
Race condition in server when accepting a connection - fixes #94
Browse files Browse the repository at this point in the history
  • Loading branch information
vietj committed Jul 10, 2018
1 parent 3ea58c1 commit fdb8f73
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 71 deletions.
53 changes: 23 additions & 30 deletions src/main/java/io/vertx/mqtt/impl/MqttEndpointImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,16 @@ public boolean isAutoKeepAlive() {
}

public boolean isConnected() {
return this.isConnected;
synchronized (this.conn) {
return this.isConnected;
}
}

public MqttEndpoint setClientIdentifier(String clientIdentifier) {

this.clientIdentifier = clientIdentifier;
synchronized (this.conn) {
this.clientIdentifier = clientIdentifier;
}
return this;
}

Expand Down Expand Up @@ -309,27 +313,29 @@ private MqttEndpointImpl connack(MqttConnectReturnCode returnCode, boolean sessi

public MqttEndpointImpl accept(boolean sessionPresent) {

if (this.isConnected) {
throw new IllegalStateException("Connection already accepted");
}
synchronized (conn) {
if (this.isConnected) {
throw new IllegalStateException("Connection already accepted");
}

return this.connack(MqttConnectReturnCode.CONNECTION_ACCEPTED, sessionPresent);
return this.connack(MqttConnectReturnCode.CONNECTION_ACCEPTED, sessionPresent);
}
}

public MqttEndpointImpl reject(MqttConnectReturnCode returnCode) {

if (returnCode == MqttConnectReturnCode.CONNECTION_ACCEPTED) {
throw new IllegalArgumentException("Need to use the 'accept' method for accepting connection");
}
synchronized (conn) {
if (returnCode == MqttConnectReturnCode.CONNECTION_ACCEPTED) {
throw new IllegalArgumentException("Need to use the 'accept' method for accepting connection");
}

// sessionPresent flag has no meaning in this case, the network connection will be closed
return this.connack(returnCode, false);
// sessionPresent flag has no meaning in this case, the network connection will be closed
return this.connack(returnCode, false);
}
}

public MqttEndpointImpl subscribeAcknowledge(int subscribeMessageId, List<MqttQoS> grantedQoSLevels) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.SUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -346,8 +352,6 @@ public MqttEndpointImpl subscribeAcknowledge(int subscribeMessageId, List<MqttQo

public MqttEndpointImpl unsubscribeAcknowledge(int unsubscribeMessageId) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.UNSUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -362,8 +366,6 @@ public MqttEndpointImpl unsubscribeAcknowledge(int unsubscribeMessageId) {

public MqttEndpointImpl publishAcknowledge(int publishMessageId) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PUBACK, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -378,8 +380,6 @@ public MqttEndpointImpl publishAcknowledge(int publishMessageId) {

public MqttEndpointImpl publishReceived(int publishMessageId) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PUBREC, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -394,8 +394,6 @@ public MqttEndpointImpl publishReceived(int publishMessageId) {

public MqttEndpointImpl publishRelease(int publishMessageId) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PUBREL, false, MqttQoS.AT_LEAST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -410,8 +408,6 @@ public MqttEndpointImpl publishRelease(int publishMessageId) {

public MqttEndpointImpl publishComplete(int publishMessageId) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PUBCOMP, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttMessageIdVariableHeader variableHeader =
Expand All @@ -426,8 +422,6 @@ public MqttEndpointImpl publishComplete(int publishMessageId) {

public MqttEndpointImpl publish(String topic, Buffer payload, MqttQoS qosLevel, boolean isDup, boolean isRetain) {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PUBLISH, isDup, qosLevel, isRetain, 0);
MqttPublishVariableHeader variableHeader =
Expand All @@ -444,8 +438,6 @@ public MqttEndpointImpl publish(String topic, Buffer payload, MqttQoS qosLevel,

public MqttEndpointImpl pong() {

this.checkConnected();

MqttFixedHeader fixedHeader =
new MqttFixedHeader(MqttMessageType.PINGRESP, false, MqttQoS.AT_MOST_ONCE, false, 0);

Expand Down Expand Up @@ -672,11 +664,12 @@ public SocketAddress remoteAddress() {
}
}

public MqttEndpointImpl write(io.netty.handler.codec.mqtt.MqttMessage mqttMessage) {
private void write(io.netty.handler.codec.mqtt.MqttMessage mqttMessage) {
synchronized (this.conn) {
this.checkClosed();
if (mqttMessage.fixedHeader().messageType() != MqttMessageType.CONNACK) {
this.checkConnected();
}
this.conn.writeMessage(mqttMessage);
return this;
}
}

Expand Down
93 changes: 53 additions & 40 deletions src/main/java/io/vertx/mqtt/impl/MqttServerConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public MqttServerConnection(NetSocketInternal so, MqttServerOptions options) {
*
* @param msg message to handle
*/
synchronized void handleMessage(Object msg) {
void handleMessage(Object msg) {

// handling directly native Netty MQTT messages, some of them are translated
// to the related Vert.x ones for polyglotization
Expand Down Expand Up @@ -195,13 +195,6 @@ private void handleConnect(MqttConnectMessage msg) {
return;
}

// if client sent one more CONNECT packet
if (endpoint != null) {
//we should treat it as a protocol violation and disconnect the client
endpoint.close();
return;
}

// retrieve will information from CONNECT message
MqttWill will =
new MqttWill(msg.variableHeader().isWillFlag(),
Expand Down Expand Up @@ -295,10 +288,12 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
*
* @param msg message with subscribe information
*/
synchronized void handleSubscribe(MqttSubscribeMessage msg) {
void handleSubscribe(MqttSubscribeMessage msg) {

if (this.checkConnected()) {
this.endpoint.handleSubscribe(msg);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleSubscribe(msg);
}
}
}

Expand All @@ -307,10 +302,12 @@ synchronized void handleSubscribe(MqttSubscribeMessage msg) {
*
* @param msg message with unsubscribe information
*/
synchronized void handleUnsubscribe(MqttUnsubscribeMessage msg) {
void handleUnsubscribe(MqttUnsubscribeMessage msg) {

if (this.checkConnected()) {
this.endpoint.handleUnsubscribe(msg);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleUnsubscribe(msg);
}
}
}

Expand All @@ -319,10 +316,12 @@ synchronized void handleUnsubscribe(MqttUnsubscribeMessage msg) {
*
* @param msg published message
*/
synchronized void handlePublish(MqttPublishMessage msg) {
void handlePublish(MqttPublishMessage msg) {

if (this.checkConnected()) {
this.endpoint.handlePublish(msg);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePublish(msg);
}
}
}

Expand All @@ -331,10 +330,12 @@ synchronized void handlePublish(MqttPublishMessage msg) {
*
* @param pubackMessageId identifier of the message acknowledged by the remote MQTT client
*/
synchronized void handlePuback(int pubackMessageId) {
void handlePuback(int pubackMessageId) {

if (this.checkConnected()) {
this.endpoint.handlePuback(pubackMessageId);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePuback(pubackMessageId);
}
}
}

Expand All @@ -343,10 +344,12 @@ synchronized void handlePuback(int pubackMessageId) {
*
* @param pubrecMessageId identifier of the message acknowledged by the remote MQTT client
*/
synchronized void handlePubrec(int pubrecMessageId) {
void handlePubrec(int pubrecMessageId) {

if (this.checkConnected()) {
this.endpoint.handlePubrec(pubrecMessageId);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubrec(pubrecMessageId);
}
}
}

Expand All @@ -355,10 +358,12 @@ synchronized void handlePubrec(int pubrecMessageId) {
*
* @param pubrelMessageId identifier of the message acknowledged by the remote MQTT client
*/
synchronized void handlePubrel(int pubrelMessageId) {
void handlePubrel(int pubrelMessageId) {

if (this.checkConnected()) {
this.endpoint.handlePubrel(pubrelMessageId);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubrel(pubrelMessageId);
}
}
}

Expand All @@ -367,30 +372,36 @@ synchronized void handlePubrel(int pubrelMessageId) {
*
* @param pubcompMessageId identifier of the message acknowledged by the remote MQTT client
*/
synchronized void handlePubcomp(int pubcompMessageId) {
void handlePubcomp(int pubcompMessageId) {

if (this.checkConnected()) {
this.endpoint.handlePubcomp(pubcompMessageId);
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePubcomp(pubcompMessageId);
}
}
}

/**
* Used internally for handling the pinreq from the remote MQTT client
*/
synchronized void handlePingreq() {
void handlePingreq() {

if (this.checkConnected()) {
this.endpoint.handlePingreq();
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handlePingreq();
}
}
}

/**
* Used for calling the disconnect handler when the remote MQTT client disconnects
*/
synchronized void handleDisconnect() {
void handleDisconnect() {

if (this.checkConnected()) {
this.endpoint.handleDisconnect();
synchronized (this.so) {
if (this.checkConnected()) {
this.endpoint.handleDisconnect();
}
}
}

Expand All @@ -401,11 +412,13 @@ synchronized void handleDisconnect() {
*/
private boolean checkConnected() {

if ((this.endpoint != null) && (this.endpoint.isConnected())) {
return true;
} else {
so.close();
throw new IllegalStateException("Received an MQTT packet from a not connected client (CONNECT not sent yet)");
synchronized (this.so) {
if ((this.endpoint != null) && (this.endpoint.isConnected())) {
return true;
} else {
so.close();
throw new IllegalStateException("Received an MQTT packet from a not connected client (CONNECT not sent yet)");
}
}
}
}
4 changes: 3 additions & 1 deletion src/main/java/io/vertx/mqtt/impl/MqttServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ public MqttServer listen(int port, String host, Handler<AsyncResult<MqttServer>>
MqttServerConnection conn = new MqttServerConnection(soi, options);

soi.messageHandler(msg -> {
conn.handleMessage(msg);
synchronized (conn) {
conn.handleMessage(msg);
}
});

conn.init(h1, h2);
Expand Down

0 comments on commit fdb8f73

Please sign in to comment.