Friday, October 23, 2009

ATL Removal: Part 3 – Tackling _Module and OBJECT_MAP

In Part 1, we saw how ATL uses COM maps to automatically generate the IUnknown implementation and to manage an object's lifetime. In this installment, we will tackle object and lifetime management for the DLL itself.

Let’s dive into the guts of DLL management. This is another place where ATL was intended to save the developer time. If you have an ATL-based DLL, you will probably find something like this in your main DLL source file:
CComModule _Module;
BEGIN_OBJECT_MAP(ObjectMap)
    OBJECT_ENTRY(CLSID_YourClass, CYourClass)
END_OBJECT_MAP()
The _Module object is designed to manage the lifetime of your DLL. Your ATL CComObject objects automatically call _Module.Lock() and _Module.Unlock(). Awesome! If you use _Module in an ATL-based DLL that also contains some non-ATL COM objects, make sure those objects lock and unlock _Module in their constructors and destructors, respectively; last year I fixed a bug where my DLL was being unloaded prematurely because of this. The first step in replacing _Module is to add:
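// Count of live COM objects, class factories, and LockServer(TRUE) calls.
// The DLL can be unloaded only when this count is zero.
LONG g_cLockCount = 0;

inline VOID IncModuleCount()
{
    InterlockedIncrement(&g_cLockCount);
} // IncModuleCount

inline VOID DecModuleCount()
{
    InterlockedDecrement(&g_cLockCount);
} // DecModuleCount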
Your DLL can use this counter to keep track of its active objects, so it knows when it is safe to unload.

Next, let’s look at the object map. What does it buy you? Essentially, a free implementation of IClassFactory. Is it hard to write your own? No, and I will show you how right now.
You can create a file called ClassFactory.h that looks like this:
#pragma once

#include <objbase.h> // IUnknown, IClassFactory, STDMETHODIMP

class CClassFactory :
    public IClassFactory
{
public:
    // IUnknown
    STDMETHODIMP_(ULONG) AddRef();
    STDMETHODIMP_(ULONG) Release();
    STDMETHODIMP QueryInterface(
        REFIID riid,
        __deref_out_opt void **ppv);

    // IClassFactory
    STDMETHODIMP CreateInstance(
        __in_opt IUnknown *punkOuter,
        REFIID riid,
        __deref_out_opt void **ppv);

    STDMETHODIMP LockServer(
        BOOL fLock);

    // Constructor / Destructor
    CClassFactory();
    ~CClassFactory();

protected:
    LONG m_cRef;
}; // CClassFactory
Now, let’s look at the implementation. Here is the content of ClassFactory.cpp:
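#include "stdafx.h" // precompiled header; expected to include ClassFactory.h and CAudioProvider
#include <new>      // std::nothrow

//---------------------------------------------------------------------------
// Begin CClassFactory implementation
//---------------------------------------------------------------------------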
extern LONG g_cLockCount;

CClassFactory::CClassFactory():
    m_cRef(1)
{
    // A live class factory keeps the DLL loaded.
    InterlockedIncrement(&g_cLockCount);
} // CClassFactory::CClassFactory

CClassFactory::~CClassFactory()
{
    InterlockedDecrement(&g_cLockCount);
} // CClassFactory::~CClassFactory

STDMETHODIMP_(ULONG) CClassFactory::AddRef()
{
    return InterlockedIncrement(&m_cRef);
} // CClassFactory::AddRef

STDMETHODIMP_(ULONG) CClassFactory::Release()
{
    LONG cRef = InterlockedDecrement(&m_cRef);

    if (!cRef)
    {
        // Last reference released.
        delete this;
    }

    return cRef;
} // CClassFactory::Release

STDMETHODIMP CClassFactory::QueryInterface(
    REFIID riid,
    __deref_out_opt void **ppv)
{
    HRESULT hr = S_OK;

    if (ppv)
    {
        *ppv = NULL;
    }
    else
    {
        hr = E_INVALIDARG;
    }

    if (S_OK == hr)
    {
        if (IID_IUnknown == riid)
        {
            AddRef();
            *ppv = static_cast<IUnknown*>(static_cast<IClassFactory*>(this));
        }
        else if (IID_IClassFactory == riid)
        {
            AddRef();
            *ppv = static_cast<IClassFactory*>(this);
        }
        else
        {
            hr = E_NOINTERFACE;
        }
    }

    return hr;
} // CClassFactory::QueryInterface

STDMETHODIMP CClassFactory::CreateInstance(
    __in_opt IUnknown *punkOuter,
    REFIID riid,
    __deref_out_opt void **ppv)
{
    HRESULT hr = S_OK;
    IUnknown *pUnknown = NULL;

    if (ppv)
    {
        *ppv = NULL;
    }
    else
    {
        hr = E_INVALIDARG;
    }

    if (S_OK == hr)
    {
        // This factory does not support aggregation.
        if (punkOuter)
        {
            hr = CLASS_E_NOAGGREGATION;
        }
    }

    if (S_OK == hr)
    {
        // CAudioProvider must be constructed with a reference count of 1
        // (like CClassFactory): QueryInterface below adds the caller's
        // reference, and the Release at the end drops this one.
        pUnknown = new(std::nothrow) CAudioProvider();

        if (!pUnknown)
        {
            hr = E_OUTOFMEMORY;
        }
    }

    if (S_OK == hr)
    {
        hr = pUnknown->QueryInterface(riid, ppv);
    }

    if (pUnknown)
    {
        pUnknown->Release();
    }

    return hr;
} // CClassFactory::CreateInstance

STDMETHODIMP CClassFactory::LockServer(
    BOOL fLock)
{
    if (fLock)
    {
        InterlockedIncrement(&g_cLockCount);
    }
    else
    {
        InterlockedDecrement(&g_cLockCount);
    }

    return S_OK;
} // CClassFactory::LockServer
//---------------------------------------------------------------------------
// End CClassFactory implementation
//---------------------------------------------------------------------------
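CreateInstance (and DllGetClassObject, which does the same with the class factory) follows a create, QueryInterface, Release sequence. The portable sketch below models only that sequence so it compiles without the Windows SDK. ToyObject, Iid, HResult, and CreateInstanceSketch are hypothetical stand-ins for a real coclass, IIDs, HRESULTs, and the factory method, and the sketch assumes the object is constructed with a reference count of 1, as CClassFactory is.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical stand-ins so the sketch builds without the Windows SDK.
enum class Iid { Unknown, AudioProvider, Other };
enum class HResult { Ok, NoInterface };

std::atomic<long> g_cLockCount{0};  // models the DLL lock count

// Toy COM-style object: starts with a reference count of 1 (like
// CClassFactory's m_cRef(1)) and holds a module lock while it is alive.
class ToyObject
{
public:
    ToyObject() : m_cRef(1) { ++g_cLockCount; }
    ~ToyObject()
    {
        assert(m_cRef.load() == 0);  // destroyed only by the final Release
        --g_cLockCount;
    }

    unsigned long AddRef() { return static_cast<unsigned long>(++m_cRef); }

    unsigned long Release()
    {
        long cRef = --m_cRef;
        if (cRef == 0)
        {
            delete this;
        }
        return static_cast<unsigned long>(cRef);
    }

    HResult QueryInterface(Iid iid, void **ppv)
    {
        *ppv = nullptr;
        if (iid == Iid::Unknown || iid == Iid::AudioProvider)
        {
            AddRef();
            *ppv = this;
            return HResult::Ok;
        }
        return HResult::NoInterface;
    }

private:
    std::atomic<long> m_cRef;
};

// Same shape as CClassFactory::CreateInstance: create with one reference,
// hand the caller its own reference via QueryInterface, then drop ours.
HResult CreateInstanceSketch(Iid iid, void **ppv)
{
    ToyObject *pObject = new ToyObject();            // refcount 1
    HResult hr = pObject->QueryInterface(iid, ppv);  // refcount 2 on success
    pObject->Release();  // back to 1 on success; 0 (object deleted) on failure
    return hr;
}
```

Whether QueryInterface succeeds or fails, the caller never has to clean up after the factory: on success it owns exactly one reference, and on failure the final Release destroys the object and the lock count drops back to zero.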
Next, you will have to fix the rest of the .cpp file that contains DllMain. Here is what a simple, ATL-free version looks like:
extern "C"
{

BOOL APIENTRY DllMain(
    HMODULE hModule,
    ULONG ulReason,
    __in_opt PVOID pReserved)
{
    BOOL fRetVal = TRUE;

    if (DLL_PROCESS_ATTACH == ulReason)
    {
        // Disable thread attach notifications
        fRetVal = DisableThreadLibraryCalls(hModule);
    }

    return fRetVal;
} // DllMain

STDAPI DllGetClassObject(
    __in REFCLSID rclsid,
    __in REFIID riid,
    __deref_out LPVOID FAR *ppv)
{
    HRESULT hr = S_OK;
    CClassFactory* pClassFactory = NULL;

    if (ppv)
    {
        *ppv = NULL;
    }
    else
    {
        hr = E_INVALIDARG;
    }

    if (S_OK == hr)
    {
        // This DLL serves exactly one class.
        if (CLSID_fdAudio != rclsid)
        {
            hr = CLASS_E_CLASSNOTAVAILABLE;
        }
    }

    if (S_OK == hr)
    {
        pClassFactory = new(std::nothrow) CClassFactory;

        if (!pClassFactory)
        {
            hr = E_OUTOFMEMORY;
        }
    }

    if (S_OK == hr)
    {
        hr = pClassFactory->QueryInterface(riid, ppv);
    }

    if (pClassFactory)
    {
        // Drop the construction reference; on success the caller keeps its own.
        pClassFactory->Release();
    }

    return hr;
} // DllGetClassObject

STDAPI DllCanUnloadNow()
{
    // S_OK tells COM the DLL may be unloaded; S_FALSE keeps it loaded.
    return (g_cLockCount == 0) ? S_OK : S_FALSE;
} // DllCanUnloadNow

} // extern "C"
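The DllGetClassObject above recognizes exactly one CLSID, and the class factory always creates CAudioProvider. If a DLL exposes several classes, a small table can take over the job OBJECT_MAP used to do. The sketch below is a portable model of that idea, not ATL's actual layout: string IDs stand in for CLSID GUIDs, a two-value enum stands in for HRESULT, and IToy, ObjectEntry, g_objectMap, CLSID_fdVideo, VideoProvider, and FindObjectEntry are all hypothetical names.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical stand-ins: string IDs instead of CLSID GUIDs, and a tiny
// result enum instead of HRESULT.
enum class HResult { Ok, ClassNotAvailable };

struct IToy
{
    virtual ~IToy() = default;
    virtual const char *Name() const = 0;
};

struct AudioProvider : IToy { const char *Name() const override { return "audio"; } };
struct VideoProvider : IToy { const char *Name() const override { return "video"; } };

// One row per creatable class: roughly the information an OBJECT_ENTRY
// line records (a CLSID plus a way to create the object).
struct ObjectEntry
{
    const char *clsid;
    IToy *(*create)();
};

template <typename T>
IToy *CreateObject()
{
    return new T();
}

const ObjectEntry g_objectMap[] =
{
    { "CLSID_fdAudio", &CreateObject<AudioProvider> },
    { "CLSID_fdVideo", &CreateObject<VideoProvider> },  // hypothetical second class
};

// The lookup DllGetClassObject would do before creating a class factory.
HResult FindObjectEntry(const char *clsid, const ObjectEntry **ppEntry)
{
    assert(clsid != nullptr && ppEntry != nullptr);
    *ppEntry = nullptr;

    for (const ObjectEntry &entry : g_objectMap)
    {
        if (std::strcmp(entry.clsid, clsid) == 0)
        {
            *ppEntry = &entry;
            return HResult::Ok;              // S_OK in the real DLL
        }
    }
    return HResult::ClassNotAvailable;       // CLASS_E_CLASSNOTAVAILABLE
}
```

In a real DLL, DllGetClassObject would run this lookup first and pass the matching create pointer to the class factory, whose CreateInstance would call it instead of hard-coding CAudioProvider.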
Finally, you will have to update your ATL-free COM objects so that each constructor calls IncModuleCount() and each destructor calls DecModuleCount(), the two helpers added at the start of this post. That keeps the DLL lock count accurate for every live object.
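The constructor/destructor pairing can be checked in isolation. This portable sketch uses std::atomic<long> in place of a LONG updated with InterlockedIncrement/InterlockedDecrement; CPlainObject, DllCanUnloadNowSketch, and LockServerSketch are hypothetical names, and the class omits QueryInterface because only the module count matters here.

```cpp
#include <atomic>
#include <cassert>

// Portable stand-in for g_cLockCount and the Interlocked* helpers.
std::atomic<long> g_cLockCount{0};

inline void IncModuleCount() { ++g_cLockCount; }

inline void DecModuleCount()
{
    long n = --g_cLockCount;
    assert(n >= 0);  // catches an unbalanced DecModuleCount call
    (void)n;
}

// Same test as DllCanUnloadNow: true means S_OK, false means S_FALSE.
inline bool DllCanUnloadNowSketch() { return g_cLockCount.load() == 0; }

// Hypothetical ATL-free COM-style class; only the module-count calls matter.
class CPlainObject
{
public:
    CPlainObject() : m_cRef(1) { IncModuleCount(); }  // object keeps the DLL loaded
    ~CPlainObject() { DecModuleCount(); }             // last object lets it go

    unsigned long AddRef() { return static_cast<unsigned long>(++m_cRef); }

    unsigned long Release()
    {
        long cRef = --m_cRef;
        if (cRef == 0)
        {
            delete this;
        }
        return static_cast<unsigned long>(cRef);
    }

private:
    std::atomic<long> m_cRef;
};

// Same logic as CClassFactory::LockServer.
inline void LockServerSketch(bool fLock)
{
    if (fLock)
    {
        IncModuleCount();
    }
    else
    {
        DecModuleCount();
    }
}
```

Either a live object or an outstanding LockServer(TRUE) keeps the count above zero, which is exactly the condition DllCanUnloadNow reports back to COM.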
Hopefully this will get you one step closer to being ATL-free.
