Skip to content

Commit

Permalink
Improved callback support added Task support and passing SessionRecord
Browse files Browse the repository at this point in the history
    This modifies the existing callback support to work with Task based callbacks.  It also passes the SessionRecord (when available) to the callback (as not yet committed for prekey), needed for protocol version check code that already exists).
  • Loading branch information
mitchcapper committed Aug 17, 2018
1 parent 8cbd533 commit 87c1b1f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 30 deletions.
37 changes: 19 additions & 18 deletions libsignal-protocol-dotnet/DecryptionCallback.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
/**
* Copyright (C) 2016 smndtrl, langboost
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

using libsignal.state;
using System.Threading.Tasks;
/**
* Copyright (C) 2016 smndtrl, langboost
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
namespace libsignal
{
public interface DecryptionCallback
{
void handlePlaintext(byte[] plaintext);
Task handlePlaintext(byte[] plaintext, SessionRecord sessionRecord);
}
}
23 changes: 14 additions & 9 deletions libsignal-protocol-dotnet/SessionCipher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
using Strilanc.Value;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

namespace libsignal
{
Expand Down Expand Up @@ -141,7 +142,9 @@ public CiphertextMessage encrypt(byte[] paddedMessage)
*/
public byte[] decrypt(PreKeySignalMessage ciphertext)
{
return decrypt(ciphertext, new NullDecryptionCallback());
var tsk = (decrypt(ciphertext, new NullDecryptionCallback()));
tsk.Wait();
return tsk.Result;
}

/**
Expand All @@ -165,7 +168,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext)
* @throws InvalidKeyException when the message is formatted incorrectly.
* @throws UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted.
*/
public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback)
public Task<byte[]> decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callback)
{
lock (SESSION_LOCK)
{
Expand All @@ -175,7 +178,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac

identityKeyStore.SaveIdentity(remoteAddress, sessionRecord.getSessionState().getRemoteIdentityKey());

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, sessionRecord).Wait();

sessionStore.StoreSession(remoteAddress, sessionRecord);

Expand All @@ -184,7 +187,7 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac
preKeyStore.RemovePreKey(unsignedPreKeyId.ForceGetValue());
}

return plaintext;
return Task.FromResult(plaintext);
}
}

Expand All @@ -202,7 +205,9 @@ public byte[] decrypt(PreKeySignalMessage ciphertext, DecryptionCallback callbac
*/
public byte[] decrypt(SignalMessage ciphertext)
{
return decrypt(ciphertext, new NullDecryptionCallback());
var tsk = decrypt(ciphertext, new NullDecryptionCallback());
tsk.Wait();
return tsk.Result;
}

/**
Expand All @@ -223,7 +228,7 @@ public byte[] decrypt(SignalMessage ciphertext)
* is no longer supported.
* @throws NoSessionException if there is no established session for this contact.
*/
public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback)
public Task<byte[]> decrypt(SignalMessage ciphertext, DecryptionCallback callback)
{
lock (SESSION_LOCK)
{
Expand All @@ -241,11 +246,11 @@ public byte[] decrypt(SignalMessage ciphertext, DecryptionCallback callback)
throw new UntrustedIdentityException(remoteAddress.Name, sessionRecord.getSessionState().getRemoteIdentityKey());
}

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, sessionRecord).Wait();//no async in a lock

sessionStore.StoreSession(remoteAddress, sessionRecord);

return plaintext;
return Task.FromResult(plaintext);
}
}

Expand Down Expand Up @@ -425,7 +430,7 @@ private byte[] getPlaintext(MessageKeys messageKeys, byte[] cipherText)
private class NullDecryptionCallback : DecryptionCallback
{

public void handlePlaintext(byte[] plaintext) { }
public Task handlePlaintext(byte[] plaintext, SessionRecord sessionRecord) => Task.CompletedTask;
}
}
}
9 changes: 6 additions & 3 deletions libsignal-protocol-dotnet/groups/GroupCipher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
using libsignal.groups.ratchet;
using libsignal.groups.state;
using libsignal.protocol;
using libsignal.state;
using libsignal.util;
using System;
using System.Threading.Tasks;

namespace libsignal.groups
{
Expand Down Expand Up @@ -134,7 +136,7 @@ public byte[] decrypt(byte[] senderKeyMessageBytes, DecryptionCallback callback)

byte[] plaintext = getPlainText(senderKey.getIv(), senderKey.getCipherKey(), senderKeyMessage.getCipherText());

callback.handlePlaintext(plaintext);
callback.handlePlaintext(plaintext, null).Wait();

senderKeyStore.storeSenderKey(senderKeyId, record);

Expand Down Expand Up @@ -217,8 +219,9 @@ private byte[] getCipherText(byte[] iv, byte[] key, byte[] plaintext)

private class NullDecryptionCallback : DecryptionCallback
{
public void handlePlaintext(byte[] plaintext) { }
}
public Task handlePlaintext(byte[] plaintext, SessionRecord sessionRecord) => Task.CompletedTask;

}

}
}

0 comments on commit 87c1b1f

Please sign in to comment.