#region Copyright 2011-2014 by Roger Knapp, Licensed under the Apache License, Version 2.0
/* Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#endregion
using System;
using System.IO;
using System.Threading;
using CSharpTest.Net.Serialization;
using CSharpTest.Net.IO;

namespace CSharpTest.Net.Crypto
{
    public static partial class SecureTransfer
    {
        #region Received EventArgs

        /// <summary> Event args that provides details about the start of a transfer </summary>
        public class BeginTransferEventArgs : EventArgs
        {
            private readonly Guid _transferId;
            private readonly string _location;
            private long _totalSize;

            internal BeginTransferEventArgs(Guid transferId, string location, long totalSize)
            {
                _transferId = transferId;
                _location = location;
                _totalSize = totalSize;
            }

            /// <summary> The client-provided unique identifier for this transfer </summary>
            public Guid TransferId { get { return _transferId; } }

            /// <summary> The client-provided name of the transfer </summary>
            public string Location { get { return _location; } }

            /// <summary> The full length of the file being transferred </summary>
            public long TotalSize
            {
                get { return _totalSize; }
                protected set { _totalSize = value; }
            }
        }

        /// <summary> Event args that provides details about the contents of a transfer </summary>
        public class BytesReceivedEventArgs : BeginTransferEventArgs
        {
            private readonly long _writeOffset;
            private readonly byte[] _bytesReceived;

            internal BytesReceivedEventArgs(Guid transferId, string location, long totalSize, long writeOffset, byte[] bytesReceived)
                : base(transferId, location, totalSize)
            {
                _writeOffset = writeOffset;
                _bytesReceived = bytesReceived;
            }

            /// <summary> The offset at which BytesReceived should be written </summary>
            public long WriteOffset { get { return _writeOffset; } }

            /// <summary> The bytes that should be written at WriteOffset </summary>
            public byte[] BytesReceived { get { return _bytesReceived; } }
        }

        /// <summary> Event args that provides details about the completion of a transfer </summary>
        public class CompleteTransferEventArgs : BeginTransferEventArgs
        {
            private readonly Hash _contentHash;

            internal CompleteTransferEventArgs(Guid transferId, string location, long totalSize, Hash contentHash)
                : base(transferId, location, totalSize)
            {
                _contentHash = contentHash;
            }

            /// <summary> The SHA-256 hash of the entire content file transferred </summary>
            public Hash ContentHash { get { return _contentHash; } }
        }

        /// <summary>
        /// Event args raised to request a range of bytes for a download. TotalSize is -1
        /// until the handler calls <see cref="SetBytes"/>.
        /// </summary>
        public class DownloadBytesEventArgs : BeginTransferEventArgs
        {
            private readonly long _readOffset;
            private readonly int _readLength;
            private byte[] _bytesRead;

            internal DownloadBytesEventArgs(Guid transferId, string location, long readOffset, int readLength)
                : base(transferId, location, -1)
            {
                _readOffset = readOffset;
                _readLength = readLength;
                _bytesRead = null;
            }

            /// <summary> The offset at which the bytes should be read </summary>
            public long ReadOffset { get { return _readOffset; } }

            /// <summary> The number of bytes to read </summary>
            public int ReadLength { get { return _readLength; } }

            /// <summary>
            /// Supplies the total size of the content and the bytes read: ReadLength bytes
            /// starting at ReadOffset.
            /// </summary>
            /// <param name="totalSize">The full length of the content being downloaded</param>
            /// <param name="bytesRead">The bytes read; must be exactly ReadLength long when ReadLength is greater than zero</param>
            public void SetBytes(long totalSize, byte[] bytesRead)
            {
                TotalSize = totalSize;
                _bytesRead = bytesRead;
            }

            /// <summary>
            /// Writes the supplied bytes to the output when ReadLength is greater than zero,
            /// asserting first that exactly ReadLength bytes were supplied.
            /// </summary>
            internal void WriteTo(Stream output)
            {
                if (ReadLength > 0)
                {
                    Check.Assert(_bytesRead != null && _bytesRead.Length == ReadLength);
                    output.Write(_bytesRead, 0, ReadLength);
                }
            }
        }

        #endregion

        /// <summary>
        /// Provides a file transfer handler for the server-side (receiver) of file transfers.
        /// </summary>
        public class Server
        {
            private readonly RSAPrivateKey _privateKey;
            private readonly RSAPublicKey _clientKey;
            private readonly INameValueStore _storage;

            /// <summary>
            /// Constructs/reconstructs the server-side receiver to process one or more messages. This class
            /// maintains all state in the INameValueStore so it may be destroyed between requests, or there
            /// may be multiple instances handling requests, provided that all instances have access to the
            /// underlying storage provided by the INameValueStore instance.
            /// </summary>
            /// <param name="privateKey">The private key used for this server</param>
            /// <param name="clientKey">The public key of the client to allow</param>
            /// <param name="storage">The state storage used between requests</param>
            public Server(RSAPrivateKey privateKey, RSAPublicKey clientKey, INameValueStore storage)
            {
                _privateKey = privateKey;
                _clientKey = clientKey;
                _storage = storage;

                NonceSize = 32;
                KeyBytes = 32;
                MaxInboundFileChunk = ushort.MaxValue;
                MaxOutboundFileChunk = 1000 * 1024;
            }

            /// <summary>
            /// The amount of random data returned from the server to generate a session key
            /// </summary>
            protected int KeyBytes { get; set; }

            /// <summary> The number of random bytes to use for a nonce </summary>
            public int NonceSize { get; set; }

            /// <summary>
            /// The maximum number of file bytes a client may upload in one message; this value is
            /// returned to the client in the upload response. The actual message size will be longer by
            /// 100 or so bytes + SHA256 signature length (privateKey.ExportParameters().Modulus.Length).
            /// To be certain a client does not exceed a specific size, allow for an additional 2500 bytes.
            /// </summary>
            public int MaxInboundFileChunk { get; set; }

            /// <summary>
            /// The chunk size advertised to the client in the download response. Content whose total
            /// length does not exceed this value is returned in full within the download response itself;
            /// larger content is fetched by the client with follow-up download-bytes requests.
            /// </summary>
            public int MaxOutboundFileChunk { get; set; }

            #region State Management

            /// <summary> returns true if the value exists </summary>
            protected virtual bool HasState(Guid transferId, string name)
            {
                string value;
                return _storage.Read(transferId.ToString("N"), name, out value);
            }

            /// <summary> returns the value identified </summary>
            /// <exception cref="InvalidDataException">Raised when the value does not exist</exception>
            protected virtual string ReadState(Guid transferId, string name)
            {
                string value;
                if (!_storage.Read(transferId.ToString("N"), name, out value))
                    throw new InvalidDataException();
                return value;
            }

            /// <summary> stores the value identified </summary>
            protected virtual void WriteState(Guid transferId, string name, string value)
            {
                _storage.Write(transferId.ToString("N"), name, value);
            }

            /// <summary> removes the value identified </summary>
            protected virtual void DeleteState(Guid transferId, string name)
            {
                _storage.Delete(transferId.ToString("N"), name);
            }

            /// <summary> removes all values for a given transfer </summary>
            protected virtual void Delete(Guid transferId)
            {
                string path = transferId.ToString("N");
                _storage.Delete(path, "start-time");
                _storage.Delete(path, "nonce");
                // "server-key" is written by NonceRequest alongside "nonce"; it must also be removed
                // here so a transfer that times out before the upload/download request does not
                // leave key material behind in the store.
                _storage.Delete(path, "server-key");
                _storage.Delete(path, "session-key");
                _storage.Delete(path, "total-length");
                _storage.Delete(path, "location");
            }

            #endregion

            /// <summary> Reads the stored session key for the transfer </summary>
            private Salt SessionKey(Guid transferId)
            {
                return Salt.FromString(ReadState(transferId, "session-key"));
            }

            #region Events

            /// <summary> Raised when an error occurs </summary>
            public event ErrorEventHandler ErrorRaised;
            /// <summary> Raised when a transfer begins </summary>
            public event EventHandler BeginTransfer;
            /// <summary> Raised when bytes are received </summary>
            public event EventHandler BytesReceived;
            /// <summary> Raised when a transfer completes </summary>
            public event EventHandler CompleteTransfer;
            /// <summary> Raised during a download request </summary>
            public event EventHandler DownloadBytes;

            // Each raiser copies the event delegate to a local first so that an unsubscribe on
            // another thread cannot null the field between the null check and the invocation.

            private void OnErrorRaised(Exception error)
            {
                ErrorEventHandler handler = ErrorRaised;
                if (handler != null)
                    handler(this, new ErrorEventArgs(error));
            }

            private void OnBeginTransfer(Guid transferId, string destination, long length)
            {
                EventHandler handler = BeginTransfer;
                if (handler != null)
                    handler(this, new BeginTransferEventArgs(transferId, destination, length));
            }

            private void OnBytesReceived(Guid transferId, string destination, long totalSize, long offset, byte[] bytes)
            {
                EventHandler handler = BytesReceived;
                if (handler != null)
                    handler(this, new BytesReceivedEventArgs(transferId, destination, totalSize, offset, bytes));
            }

            private void OnCompleteTransfer(Guid transferId, string location, long totalSize, Hash contentHash)
            {
                EventHandler handler = CompleteTransfer;
                if (handler != null)
                    handler(this, new CompleteTransferEventArgs(transferId, location, totalSize, contentHash));
            }

            /// <summary>
            /// Raises DownloadBytes, then copies the bytes supplied by the handler to the output.
            /// The assertion on TotalSize fails when no handler called SetBytes, since TotalSize
            /// starts at -1.
            /// </summary>
            private void OnDownloadBytes(Guid transferId, string location, out long totalSize, long offset, int length, Stream output)
            {
                DownloadBytesEventArgs args = new DownloadBytesEventArgs(transferId, location, offset, length);
                EventHandler handler = DownloadBytes;
                if (handler != null)
                    handler(this, args);

                Check.Assert(args.TotalSize >= 0);
                totalSize = args.TotalSize;
                args.WriteTo(output);
            }

            #endregion

            /// <summary>
            /// Processes an inbound message and returns the result
            /// </summary>
            /// <param name="data">The stream containing the inbound message</param>
            /// <returns>A stream containing the response message</returns>
            /// <exception cref="InvalidDataException">Raised for any internal error</exception>
            public Stream Receive(Stream data)
            {
                try
                {
                    using (Message req = new Message(data, _privateKey, SessionKey))
                    {
                        VerifyMessage(req);

                        switch (req.State)
                        {
                            case TransferState.NonceRequest:
                                return NonceRequest(req).ToStream(_privateKey);
                            case TransferState.UploadRequest:
                                return TransferRequest(req).ToStream(_privateKey);
                            case TransferState.SendBytesRequest:
                                return SendBytesRequest(req).ToStream(_privateKey);
                            case TransferState.UploadCompleteRequest:
                                return CompleteRequest(req).ToStream(_privateKey);
                            case TransferState.DownloadRequest:
                                return DownloadRequest(req).ToStream(_privateKey);
                            case TransferState.DownloadBytesRequest:
                                return DownloadBytesRequest(req).ToStream(_privateKey);
                            case TransferState.DownloadCompleteRequest:
                                return DownloadCompleteRequest(req).ToStream(_privateKey);
                            default:
                                throw new InvalidDataException();
                        }
                    }
                }
                catch (Exception error)
                {
                    // The real error is only reported through ErrorRaised; the caller always sees
                    // the same InvalidDataException after a short random delay.
                    OnErrorRaised(error);
                    Thread.Sleep(new Random().Next(10, 100));
                    throw new InvalidDataException();
                }
            }

            /// <summary>
            /// Verifies that the stored state is consistent with the message's state and that the
            /// transfer has not expired: two hours overall, or two minutes between the nonce request
            /// and an upload request. Expired transfers have their state deleted.
            /// </summary>
            private void VerifyMessage(Message msg)
            {
                // "start-time" must exist for every request except the initial nonce request.
                Check.Assert(HasState(msg.TransferId, "start-time") == (msg.State != TransferState.NonceRequest));

                if (msg.State > TransferState.NonceRequest)
                {
                    long startUtcTicks = long.Parse(ReadState(msg.TransferId, "start-time"));
                    DateTime started = new DateTime(startUtcTicks, DateTimeKind.Utc);
                    TimeSpan elapsed = DateTime.UtcNow - started;

                    if (elapsed.TotalHours > 2
                        || (msg.State == TransferState.UploadRequest && elapsed.TotalMinutes > 2))
                    {
                        Delete(msg.TransferId);
                        throw new TimeoutException();
                    }
                }

                if (msg.State >= TransferState.StartSessionKey)
                    Check.Assert(HasState(msg.TransferId, "session-key"));
            }

            /// <summary>
            /// Starts a transfer: records the start time, then generates and stores a nonce and the
            /// server's share of the session key, returning both to the client.
            /// </summary>
            private Message NonceRequest(Message req)
            {
                req.VerifySignature(_clientKey);
                Check.Assert(HasState(req.TransferId, "start-time") == false);

                WriteState(req.TransferId, "start-time", DateTime.UtcNow.Ticks.ToString());

                // The nonce guards the next request against replay, so it must be unpredictable;
                // System.Random is time-seeded and not suitable. Use the same cryptographic
                // generator for both the nonce and the key material.
                System.Security.Cryptography.RNGCryptoServiceProvider random =
                    new System.Security.Cryptography.RNGCryptoServiceProvider();

                byte[] nonce = new byte[NonceSize];
                random.GetBytes(nonce);
                WriteState(req.TransferId, "nonce", Convert.ToBase64String(nonce));

                byte[] keydata = new byte[KeyBytes];
                random.GetBytes(keydata);
                WriteState(req.TransferId, "server-key", Convert.ToBase64String(keydata));

                Message response = new Message(TransferState.NonceResponse, req.TransferId, _clientKey, NoSession);
                response.Write(nonce);
                response.Write(keydata);
                return response;
            }

            /// <summary>
            /// Derives the session secret as the SHA-256 hash of the client key bits followed by
            /// the server key bits.
            /// </summary>
            private static Salt SessionSecret(byte[] clientKeyBits, byte[] serverKeyBits)
            {
                Salt sessionSecret = Salt.FromBytes(
                    Hash.SHA256(
                        new CombinedStream(
                            new MemoryStream(clientKeyBits, false),
                            new MemoryStream(serverKeyBits, false)
                        )
                    ).ToArray()
                );
                return sessionSecret;
            }

            /// <summary>
            /// Handles an upload request: consumes the stored nonce and server key, validates the
            /// client's nonce hash, raises BeginTransfer and stores the session state.
            /// </summary>
            private Message TransferRequest(Message req)
            {
                Check.Assert(HasState(req.TransferId, "start-time")
                    && HasState(req.TransferId, "nonce")
                    && HasState(req.TransferId, "server-key")
                    && !HasState(req.TransferId, "session-key"));

                // The nonce and server key are single-use: remove them as soon as they are read.
                byte[] nonceExpected = Convert.FromBase64String(ReadState(req.TransferId, "nonce"));
                DeleteState(req.TransferId, "nonce");
                byte[] serverKeyBits = Convert.FromBase64String(ReadState(req.TransferId, "server-key"));
                DeleteState(req.TransferId, "server-key");

                byte[] hnonce = req.ReadBytes(32);
                long length = req.ReadInt64();
                string name = req.ReadString(1024);
                byte[] clientKeyBits = req.ReadBytes(32);
                req.VerifySignature(_clientKey);

                Check.Assert(
                    Hash.SHA256(nonceExpected).Equals(Hash.FromBytes(hnonce))
                );

                OnBeginTransfer(req.TransferId, name, length);

                WriteState(req.TransferId, "total-length", length.ToString());
                WriteState(req.TransferId, "location", name ?? String.Empty);
                WriteState(req.TransferId, "session-key", SessionSecret(clientKeyBits, serverKeyBits).ToString());

                Message response = new Message(TransferState.UploadResponse, req.TransferId, _clientKey, SessionKey);
                response.Write(MaxInboundFileChunk);
                return response;
            }

            /// <summary>
            /// Handles a chunk of uploaded bytes: validates that the chunk fits within the declared
            /// total length, raises BytesReceived and echoes the position back to the client.
            /// </summary>
            private Message SendBytesRequest(Message req)
            {
                long position = req.ReadInt64();
                byte[] bytes = req.ReadBytes(MaxInboundFileChunk);
                req.VerifySignature(_clientKey);

                string location = ReadState(req.TransferId, "location");
                long totalSize = long.Parse(ReadState(req.TransferId, "total-length"));

                Check.InRange(position, 0, totalSize - bytes.Length);
                OnBytesReceived(req.TransferId, location, totalSize, position, bytes);

                Message response = new Message(TransferState.SendBytesResponse, req.TransferId, _clientKey, SessionKey);
                response.Write(position);
                return response;
            }

            /// <summary>
            /// Handles completion of an upload: raises CompleteTransfer with the client-provided
            /// content hash. The transfer's state is always deleted, even on failure.
            /// </summary>
            private Message CompleteRequest(Message req)
            {
                try
                {
                    string name = req.ReadString(1024);
                    Hash contentHash = Hash.FromBytes(req.ReadBytes(32));
                    req.VerifySignature(_clientKey);

                    long totalSize = long.Parse(ReadState(req.TransferId, "total-length"));
                    string location = ReadState(req.TransferId, "location");
                    Check.Assert(location == name);

                    OnCompleteTransfer(req.TransferId, location, totalSize, contentHash);

                    Message response = new Message(TransferState.UploadCompleteResponse, req.TransferId, _clientKey, SessionKey);
                    return response;
                }
                finally
                {
                    Delete(req.TransferId);
                }
            }

            /// <summary>
            /// Handles a download request: consumes the stored nonce and server key, validates the
            /// client's nonce hash and queries the content length. Content no longer than
            /// MaxOutboundFileChunk is returned inline and the transfer state is deleted; otherwise
            /// the session state is stored for follow-up download-bytes requests.
            /// </summary>
            private Message DownloadRequest(Message req)
            {
                Check.Assert(HasState(req.TransferId, "start-time")
                    && HasState(req.TransferId, "nonce")
                    && HasState(req.TransferId, "server-key")
                    && !HasState(req.TransferId, "session-key"));

                // The nonce and server key are single-use: remove them as soon as they are read.
                byte[] nonceExpected = Convert.FromBase64String(ReadState(req.TransferId, "nonce"));
                DeleteState(req.TransferId, "nonce");
                byte[] serverKeyBits = Convert.FromBase64String(ReadState(req.TransferId, "server-key"));
                DeleteState(req.TransferId, "server-key");

                byte[] hnonce = req.ReadBytes(32);
                string name = req.ReadString(1024);
                byte[] clientKeyBits = req.ReadBytes(32);
                req.VerifySignature(_clientKey);

                Check.Assert(
                    Hash.SHA256(nonceExpected).Equals(Hash.FromBytes(hnonce))
                );

                Salt sessionKey = SessionSecret(clientKeyBits, serverKeyBits);

                // A zero-length read is used only to learn the total size of the content.
                long length;
                OnDownloadBytes(req.TransferId, name, out length, 0, 0, Stream.Null);

                // The session key is not stored yet, so hand it to the message directly.
                Message response = new Message(TransferState.DownloadResponse, req.TransferId, _clientKey, s => sessionKey);
                response.Write(MaxOutboundFileChunk);
                response.Write(length);

                if (length <= MaxOutboundFileChunk)
                {
                    Delete(req.TransferId);
                    using (MemoryStream ms = new MemoryStream())
                    {
                        OnDownloadBytes(req.TransferId, name, out length, 0, (int)length, ms);
                        Check.Assert(ms.Position == length);
                        response.Write(ms.ToArray());
                    }
                }
                else
                {
                    WriteState(req.TransferId, "total-length", length.ToString());
                    WriteState(req.TransferId, "location", name ?? String.Empty);
                    WriteState(req.TransferId, "session-key", sessionKey.ToString());
                    response.Write(new byte[0]);
                }
                return response;
            }

            /// <summary>
            /// Handles a request for a range of download bytes: validates the requested range
            /// against the stored total length, raises DownloadBytes and returns the bytes.
            /// </summary>
            private Message DownloadBytesRequest(Message req)
            {
                string location = req.ReadString(1024);
                long position = req.ReadInt64();
                int count = req.ReadInt32();
                req.VerifySignature(_clientKey);

                Check.Assert(location == ReadState(req.TransferId, "location"));
                long totalSize = long.Parse(ReadState(req.TransferId, "total-length"));

                // Reject a negative count up front; the range checks below alone would accept it
                // whenever position + count is still non-negative, and the DownloadBytes handler
                // would then be asked to read a negative number of bytes.
                Check.Assert(count >= 0);
                Check.InRange(position, 0, totalSize);
                Check.InRange(position + count, 0, totalSize);

                Message response = new Message(TransferState.DownloadBytesResponse, req.TransferId, _clientKey, SessionKey);
                using (MemoryStream ms = new MemoryStream())
                {
                    long length;
                    OnDownloadBytes(req.TransferId, location, out length, position, count, ms);
                    Check.Assert(length == totalSize);
                    Check.Assert(ms.Position == count);
                    response.Write(ms.ToArray());
                }
                return response;
            }

            /// <summary>
            /// Handles completion of a download by deleting the transfer's state.
            /// </summary>
            private Message DownloadCompleteRequest(Message req)
            {
                req.VerifySignature(_clientKey);
                Delete(req.TransferId);
                return Message.EmptyMessage;
            }
        }
    }
}