#region Copyright 2010-2014 by Roger Knapp, Licensed under the Apache License, Version 2.0
/* Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#endregion
using System;
using System.Runtime.InteropServices;
using System.Threading;
using CSharpTest.Net.RpcLibrary.Interop;
using CSharpTest.Net.RpcLibrary.Interop.Structs;

namespace CSharpTest.Net.RpcLibrary
{
    /// <summary>
    /// Provides server-side services for RPC
    /// </summary>
    public class RpcServerApi : IDisposable
    {
        /// <summary> The max limit of in-flight calls </summary>
        public const int MAX_CALL_LIMIT = 255;
        /// <summary> Use the default request limits </summary>
        public const int DEF_REQ_LIMIT = -1;

        /// <summary> Request-size limit handed to RpcServerRegisterIf2 when the caller passes maxRequestBytes &lt;= 0. </summary>
        private const int DEFAULT_MAX_RPC_SIZE = 80 * 1024;

        // Static counter of listening instances shared by every RpcServerApi in this class's static scope;
        // its name embeds the current process id.
        private static readonly UsageCounter _listenerCount = new UsageCounter("RpcApi.Listener.{0}", GetCurrentProcessId());

        private bool _isListening;
        private int _maxCalls;
        /// <summary> The interface Id the service is using </summary>
        public readonly Guid IID;
        private readonly RpcHandle _handle;
        private RpcExecuteHandler _handler;
        // Private monitor for _handler updates. Using lock(this) would let outside code that locks on
        // this instance contend with (or deadlock) the event accessors.
        private readonly object _handlerLock = new object();

        /// <summary>
        /// Enables verbose logging of the RPC calls to the Trace window
        /// </summary>
        public static bool VerboseLogging
        {
            get { return Log.VerboseEnabled; }
            set { Log.VerboseEnabled = value; }
        }

        /// <summary>
        /// Constructs an RPC server for the given interface guid, the guid is used to identify multiple rpc
        /// servers/services within a single process.
        /// </summary>
        /// <param name="iid">The interface id; <see cref="Guid.Empty"/> skips interface registration.</param>
        public RpcServerApi(Guid iid)
            : this(iid, MAX_CALL_LIMIT, DEF_REQ_LIMIT, false)
        { }

        /// <summary>
        /// Constructs an RPC server for the given interface guid, the guid is used to identify multiple rpc
        /// servers/services within a single process.
        /// </summary>
        /// <param name="iid">
        /// The interface id. <see cref="Guid.Empty"/> registers no interface, which still allows
        /// AddProtocol/AddAuthentication to be used without creating a channel.
        /// </param>
        /// <param name="maxCalls">
        /// Maximum concurrent calls; also used when the server starts listening. Values &lt;= 0 are replaced
        /// by <see cref="MAX_CALL_LIMIT"/> only when RpcServerRegisterIf2 is used for registration.
        /// </param>
        /// <param name="maxRequestBytes">
        /// Maximum request size. A negative value (with <paramref name="allowAnonTcp"/> false) uses the plain
        /// RpcServerRegisterIf call; otherwise values &lt;= 0 fall back to 80 KB.
        /// </param>
        /// <param name="allowAnonTcp">
        /// When true, registers with RPC_IF_ALLOW_CALLBACKS_WITH_NO_AUTH and a security callback that accepts every caller.
        /// </param>
        public RpcServerApi(Guid iid, int maxCalls, int maxRequestBytes, bool allowAnonTcp)
        {
            IID = iid;
            _maxCalls = maxCalls;
            _handle = new RpcServerHandle();

            // Guid.Empty to avoid registration of any interface allowing access to AddProtocol/AddAuthentication
            // without creating a channel
            if (!Guid.Empty.Equals(iid))
            {
                try
                {
                    ServerRegisterInterface(_handle, IID, RpcEntryPoint, maxCalls, maxRequestBytes, allowAnonTcp);
                }
                catch
                {
                    // The caller never receives this instance, so nobody else can dispose the handle.
                    _handle.Dispose();
                    throw;
                }
            }
        }

        /// <summary>
        /// Disposes of the server and stops listening if the server is currently listening
        /// </summary>
        public void Dispose()
        {
            _handler = null;
            StopListening();
            _handle.Dispose();
        }

        /// <summary>
        /// Used to ensure that the server is listening with a specific protocol type.  Once invoked this
        /// can not be undone, and all RPC servers within the process will be available on that protocol
        /// </summary>
        /// <param name="protocol">The protocol sequence to listen on.</param>
        /// <param name="endpoint">The endpoint for the protocol (e.g. a port or pipe name).</param>
        /// <param name="maxCalls">Concurrent-call limit for this protocol; also raises this instance's listen limit if larger.</param>
        /// <returns>
        /// True when the endpoint was registered; false when it was already registered (RPC_S_DUPLICATE_ENDPOINT).
        /// Any other error status is passed to RpcException.Assert.
        /// </returns>
        public bool AddProtocol(RpcProtseq protocol, string endpoint, int maxCalls)
        {
            _maxCalls = Math.Max(_maxCalls, maxCalls);
            return ServerUseProtseqEp(protocol, maxCalls, endpoint);
        }

        /// <summary>
        /// Adds a type of authentication sequence that will be allowed for RPC connections to this process.
        /// </summary>
        /// <returns>True on success; false (after logging a warning) if registration failed.</returns>
        public bool AddAuthentication(RpcAuthentication type)
        {
            return AddAuthentication(type, null);
        }

        /// <summary>
        /// Adds a type of authentication sequence that will be allowed for RPC connections to this process.
        /// </summary>
        /// <param name="type">The authentication service to register.</param>
        /// <param name="serverPrincipalName">The server principal name, or null.</param>
        /// <returns>True on success; false (after logging a warning) if registration failed.</returns>
        public bool AddAuthentication(RpcAuthentication type, string serverPrincipalName)
        {
            return ServerRegisterAuthInfo(type, serverPrincipalName);
        }

        /// <summary>
        /// Starts the RPC listener for this instance, if this is the first RPC server instance the process
        /// starts listening on the registered protocols.
        /// </summary>
        public void StartListening()
        {
            if (_isListening)
            {
                return;
            }

            _listenerCount.Increment(ServerListen, _maxCalls);
            _isListening = true;
        }

        /// <summary>
        /// Stops listening for this instance, if this is the last instance to stop listening the process
        /// stops listening on all registered protocols.
        /// </summary>
        public void StopListening()
        {
            if (!_isListening)
            {
                return;
            }

            _isListening = false;
            _listenerCount.Decrement(ServerStopListening);
        }

        /// <summary>
        /// Native entry point registered with the interface (see ServerRegisterInterface). Copies the request
        /// into managed memory, dispatches it to <see cref="Execute"/>, and copies the reply into a buffer
        /// obtained from RpcApi.Alloc (presumably released by the RPC runtime/stub - confirm).
        /// </summary>
        /// <returns>
        /// RPC_S_OK on success, RPC_S_NOT_LISTENING when no handler produced a reply, and RPC_E_FAIL
        /// (after logging) when an exception occurred.
        /// </returns>
        private uint RpcEntryPoint(IntPtr clientHandle, uint szInput, IntPtr input, out uint szOutput, out IntPtr output)
        {
            output = IntPtr.Zero;
            szOutput = 0;
            try
            {
                byte[] bytesIn = new byte[szInput];
                // Skip the copy for empty requests: Marshal.Copy rejects a null source pointer even
                // when the length is zero, which would turn an empty request into RPC_E_FAIL.
                if (bytesIn.Length > 0)
                {
                    Marshal.Copy(input, bytesIn, 0, bytesIn.Length);
                }

                byte[] bytesOut;
                using (RpcClientInfo client = new RpcClientInfo(clientHandle))
                {
                    bytesOut = Execute(client, bytesIn);
                }

                if (bytesOut == null)
                {
                    return (uint) RpcError.RPC_S_NOT_LISTENING;
                }

                szOutput = (uint) bytesOut.Length;
                output = RpcApi.Alloc(szOutput);
                Marshal.Copy(bytesOut, 0, output, bytesOut.Length);

                return (uint) RpcError.RPC_S_OK;
            }
            catch (Exception ex)
            {
                // Only release a buffer that was actually allocated before the failure.
                if (output != IntPtr.Zero)
                {
                    RpcApi.Free(output);
                }
                output = IntPtr.Zero;
                szOutput = 0;

                Log.Error(ex);
                return (uint) RpcError.RPC_E_FAIL;
            }
        }

        /// <summary>
        /// Can be over-ridden in a derived class to handle the incomming RPC request, or you can
        /// subscribe to the OnExecute event.
        /// </summary>
        /// <param name="client">Information about the calling client.</param>
        /// <param name="input">The request payload.</param>
        /// <returns>The reply payload, or null when no handler is subscribed.</returns>
        public virtual byte[] Execute(IRpcClientInfo client, byte[] input)
        {
            RpcExecuteHandler proc = _handler;
            if (proc != null)
            {
                return proc(client, input);
            }
            return null;
        }

        /// <summary>
        /// Allows a single subscription to this event to handle incomming requests rather than
        /// deriving from and overriding the Execute call.
        /// </summary>
        public event RpcExecuteHandler OnExecute
        {
            add
            {
                lock (_handlerLock)
                {
                    Check.Assert(_handler == null, "The interface id is already registered.");
                    _handler = value;
                }
            }
            remove
            {
                lock (_handlerLock)
                {
                    Check.NotNull(value);
                    // Only the currently registered handler may be removed.
                    if (_handler != null)
                    {
                        Check.Assert(
                            Object.ReferenceEquals(_handler.Target, value.Target) &&
                            Object.ReferenceEquals(_handler.Method, value.Method)
                            );
                    }
                    _handler = null;
                }
            }
        }

        /// <summary>
        /// The delegate format for the OnExecute event
        /// </summary>
        public delegate byte[] RpcExecuteHandler(IRpcClientInfo client, byte[] input);

        /// <summary> Reads the current process id, disposing the Process component it came from. </summary>
        private static int GetCurrentProcessId()
        {
            using (System.Diagnostics.Process current = System.Diagnostics.Process.GetCurrentProcess())
            {
                return current.Id;
            }
        }

        /* ********************************************************************
         * WinAPI INTEROP
         * *******************************************************************/

        /// <summary>
        /// Handle wrapper whose disposal unregisters the interface (waiting for in-flight calls to complete).
        /// </summary>
        private class RpcServerHandle : RpcHandle
        {
            protected override void DisposeHandle(ref IntPtr handle)
            {
                if (handle != IntPtr.Zero)
                {
                    RpcServerUnregisterIf(handle, IntPtr.Zero, 1);
                    handle = IntPtr.Zero;
                }
            }
        }

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerUnregisterIf", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerUnregisterIf(IntPtr IfSpec, IntPtr MgrTypeUuid, uint WaitForCallsToComplete);

        #region RpcServerXXXX routines

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerUseProtseqEpW", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerUseProtseqEp(String Protseq, int MaxCalls, String Endpoint,
                                                             IntPtr SecurityDescriptor);

        /// <summary>
        /// Registers a protocol sequence/endpoint. Returns true on RPC_S_OK, false on RPC_S_DUPLICATE_ENDPOINT;
        /// every other status goes through RpcException.Assert.
        /// </summary>
        private static bool ServerUseProtseqEp(RpcProtseq protocol, int maxCalls, String endpoint)
        {
            Log.Verbose("ServerUseProtseqEp({0})", protocol);
            RpcError err = RpcServerUseProtseqEp(protocol.ToString(), maxCalls, endpoint, IntPtr.Zero);
            if (err != RpcError.RPC_S_DUPLICATE_ENDPOINT)
            {
                RpcException.Assert(err);
            }
            return err == RpcError.RPC_S_OK;
        }

        delegate int RPC_IF_CALLBACK_FN(IntPtr Interface, IntPtr Context);

        // NOTE(review): RPC_IF_CALLBACK_FN is declared above but not named here; confirm FunctionPtr's
        // declaration in Interop (it may be intended as FunctionPtr over RPC_IF_CALLBACK_FN).
        private static readonly FunctionPtr hAuthCall = new FunctionPtr(AuthCall);

        /// <summary> Security callback installed when allowAnonTcp is set; admits every caller. </summary>
        static int AuthCall(IntPtr Interface, IntPtr Context)
        {
            return (int)RpcError.RPC_S_OK;
        }

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerRegisterIf2", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerRegisterIf2(IntPtr IfSpec, IntPtr MgrTypeUuid, IntPtr MgrEpv,
                                                            int Flags, int MaxCalls, int MaxRpcSize, IntPtr hProc);

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerRegisterIf", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerRegisterIf(IntPtr IfSpec, IntPtr MgrTypeUuid, IntPtr MgrEpv);

        /// <summary>
        /// Builds the server interface description for <paramref name="iid"/> and registers it. Uses
        /// RpcServerRegisterIf when no limits/flags are needed, otherwise RpcServerRegisterIf2. On success
        /// the interface handle is stored in <paramref name="handle"/> so disposal unregisters it.
        /// </summary>
        private static void ServerRegisterInterface(RpcHandle handle, Guid iid, RpcExecute fnExec, int maxCalls, int maxRequestBytes, bool allowAnonTcp)
        {
            const int RPC_IF_ALLOW_CALLBACKS_WITH_NO_AUTH = 0x0010;
            int flags = 0;
            IntPtr fnAuth = IntPtr.Zero;
            if (allowAnonTcp)
            {
                flags = RPC_IF_ALLOW_CALLBACKS_WITH_NO_AUTH;
                fnAuth = hAuthCall.Handle;
            }

            Ptr sIf = MIDL_SERVER_INFO.Create(handle, iid, RpcApi.TYPE_FORMAT, RpcApi.FUNC_FORMAT, fnExec);

            if (!allowAnonTcp && maxRequestBytes < 0)
            {
                RpcException.Assert(RpcServerRegisterIf(sIf.Handle, IntPtr.Zero, IntPtr.Zero));
            }
            else
            {
                RpcException.Assert(RpcServerRegisterIf2(sIf.Handle, IntPtr.Zero, IntPtr.Zero, flags,
                    maxCalls <= 0 ? MAX_CALL_LIMIT : maxCalls,
                    maxRequestBytes <= 0 ? DEFAULT_MAX_RPC_SIZE : maxRequestBytes, fnAuth));
            }

            handle.Handle = sIf.Handle;
        }

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerRegisterAuthInfoW", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerRegisterAuthInfo(String ServerPrincName, uint AuthnSvc, IntPtr GetKeyFn,
                                                                 IntPtr Arg);

        /// <summary>
        /// Registers an authentication service; logs a warning and returns false on any non-OK status.
        /// </summary>
        private static bool ServerRegisterAuthInfo(RpcAuthentication auth, string serverPrincName)
        {
            Log.Verbose("ServerRegisterAuthInfo({0})", auth);
            RpcError response = RpcServerRegisterAuthInfo(serverPrincName, (uint) auth, IntPtr.Zero, IntPtr.Zero);
            if (response != RpcError.RPC_S_OK)
            {
                Log.Warning("ServerRegisterAuthInfo - unable to register authentication type {0}", auth);
                return false;
            }
            return true;
        }

        #endregion

        #region RpcServerListen & StopListening

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcServerListen", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcServerListen(uint MinimumCallThreads, int MaxCalls, uint DontWait);

        /// <summary>
        /// Starts process-wide listening without blocking (DontWait = 1); an "already listening" status is
        /// treated as success.
        /// </summary>
        private static void ServerListen(int maxCalls)
        {
            Log.Verbose("Begin Server Listening");
            RpcError result = RpcServerListen(1, maxCalls, 1);
            if (result == RpcError.RPC_S_ALREADY_LISTENING)
            {
                result = RpcError.RPC_S_OK;
            }
            RpcException.Assert(result);
            Log.Verbose("Server Ready");
        }

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcMgmtStopServerListening", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcMgmtStopServerListening(IntPtr ignore);

        [DllImport("Rpcrt4.dll", EntryPoint = "RpcMgmtWaitServerListen", CallingConvention = CallingConvention.StdCall,
            CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern RpcError RpcMgmtWaitServerListen();

        /// <summary>
        /// Requests that listening stop, then waits for the listen operation to finish. Failures are logged
        /// as warnings rather than thrown.
        /// </summary>
        private static void ServerStopListening()
        {
            Log.Verbose("Stop Server Listening");
            RpcError result = RpcMgmtStopServerListening(IntPtr.Zero);
            if (result != RpcError.RPC_S_OK)
            {
                Log.Warning("RpcMgmtStopServerListening result = {0}", result);
            }
            result = RpcMgmtWaitServerListen();
            if (result != RpcError.RPC_S_OK)
            {
                Log.Warning("RpcMgmtWaitServerListen result = {0}", result);
            }
        }

        #endregion
    }
}