Wednesday, September 1, 2010

A horrible hack turned safe thanks to templates.

While writing shader-loading code for DirectX 11, I found myself writing the same code over and over for each shader type (vertex, pixel, geometry, hull, domain, etc.), so I figured I should make it more generic.

The catch is that Direct3D 11 has a separate method to create each type of shader, so to make the code generic I originally came up with a solution I can only describe as a "horrible hack".

In Direct3D 11, ID3D11Device has methods such as:
HRESULT __stdcall ID3D11Device::CreateVertexShader(
    const void *pShaderBytecode,            // [in]
    SIZE_T BytecodeLength,                  // [in]
    ID3D11ClassLinkage *pClassLinkage,      // [in]
    ID3D11VertexShader **ppVertexShader);   // [out]

HRESULT __stdcall ID3D11Device::CreatePixelShader(
    const void *pShaderBytecode,            // [in]
    SIZE_T BytecodeLength,                  // [in]
    ID3D11ClassLinkage *pClassLinkage,      // [in]
    ID3D11PixelShader **ppPixelShader);     // [out]

ID3D11VertexShader and ID3D11PixelShader both inherit publicly from ID3D11DeviceChild:

struct ID3D11VertexShader : public ID3D11DeviceChild { ... };
struct ID3D11PixelShader : public ID3D11DeviceChild { ... };

The trick was to define a pointer-to-member-function type, CreateShader, with the following signature:

typedef HRESULT (__stdcall ID3D11Device::*CreateShader)(const void*, SIZE_T, ID3D11ClassLinkage*, ID3D11DeviceChild**);
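To see why this generic pointer is a hack, here is a minimal portable sketch using hypothetical mock types (namespace mock, Device, VertexShader and so on imitate the d3d11.h declarations; __stdcall is dropped so it builds anywhere). The static_asserts show that the address of a concrete Create* method is neither the same type as CreateShader nor implicitly convertible to it, so an explicit cast is required; theHack (a made-up name) performs that cast.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-ins for the Direct3D 11 types (NOT the real d3d11.h).
namespace mock {
typedef long HRESULT;
typedef std::size_t SIZE_T;
struct ClassLinkage {};
struct DeviceChild {};                        // plays ID3D11DeviceChild
struct VertexShader : DeviceChild {};         // plays ID3D11VertexShader
struct PixelShader : DeviceChild {};          // plays ID3D11PixelShader

struct Device {
    HRESULT CreateVertexShader(const void*, SIZE_T, ClassLinkage*, VertexShader**) { return 0; }
    HRESULT CreatePixelShader(const void*, SIZE_T, ClassLinkage*, PixelShader**) { return 0; }
};

// The generic signature: only the last parameter differs (base class).
typedef HRESULT (Device::*CreateShader)(const void*, SIZE_T, ClassLinkage*, DeviceChild**);
} // namespace mock

// The real method pointer has a different type...
static_assert(!std::is_same<mock::CreateShader,
                            decltype(&mock::Device::CreateVertexShader)>::value,
              "different types");
// ...and no implicit conversion exists between the two.
static_assert(!std::is_convertible<decltype(&mock::Device::CreateVertexShader),
                                   mock::CreateShader>::value,
              "an explicit cast is required");

// This compiles, but calling through the result is undefined behaviour in
// standard C++; only casting it back to the original type is guaranteed.
inline mock::CreateShader theHack() {
    return reinterpret_cast<mock::CreateShader>(&mock::Device::CreateVertexShader);
}
```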

The only difference from the real methods is the last parameter, which uses the common base class of the shader interfaces.
When I create a shader, I then only have to pass a CreateShader pointing to the right ID3D11Device method, which takes an explicit cast because the real method's type does not match.
The problem is that I also have to pass the output shader as its base type. An ID3D11VertexShader* converts implicitly to an ID3D11DeviceChild*, but C++ deliberately does not convert an ID3D11VertexShader** to an ID3D11DeviceChild**, so the (wise) compiler complained when I tried to do it implicitly.
That forced me to create an intermediate ID3D11DeviceChild* from the ID3D11*Shader and pass its address to the method instead.
In other words, it looked even more horrible; that's where templates come to the rescue.
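The compiler is right to refuse: if an ID3D11VertexShader** converted to an ID3D11DeviceChild**, you could store a pixel shader through it into a vertex-shader variable. The sketch below uses hypothetical mock classes (DeviceChild, VertexShader, PixelShader and createViaBase are illustrative names, not D3D API) to check both conversions at compile time and to show a temporary-base-pointer workaround along the lines of the one described above.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical mocks standing in for ID3D11DeviceChild and two shader interfaces.
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

// Derived* -> Base* is an ordinary implicit conversion...
static_assert(std::is_convertible<VertexShader*, DeviceChild*>::value, "allowed");
// ...but Derived** -> Base** is not, because it would permit:
//   VertexShader* vs;
//   DeviceChild** pp = &vs;   // rejected by the compiler
//   *pp = new PixelShader;    // vs would now point at a PixelShader
static_assert(!std::is_convertible<VertexShader**, DeviceChild**>::value, "rejected");

// Workaround: go through a temporary base pointer, then cast back.
// The downcast is only valid if the callee really produced a VertexShader,
// which the assert checks in debug builds.
VertexShader* createViaBase(void (*create)(DeviceChild**)) {
    DeviceChild* tmp = nullptr;
    create(&tmp);
    assert(dynamic_cast<VertexShader*>(tmp) == tmp);
    return static_cast<VertexShader*>(tmp);
}
```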

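I decided to make loadShader a template on the shader type, and to use that type to pin down the exact signature of the matching ID3D11Device::Create* method.
Since C++03 allows neither default template arguments on function templates nor template typedefs, I created a helper structure that maps a shader type to its member-function-pointer type: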

template <class Shader>
struct CreateShaderHelper
{
    typedef HRESULT (__stdcall ID3D11Device::*FuncType)(const void*, SIZE_T, ID3D11ClassLinkage*, Shader**);
};

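As a portable sanity check, here is the same helper against hypothetical mock types (namespace mock, Device, VertexShader and PixelShader stand in for the d3d11.h interfaces, and __stdcall is dropped so it compiles anywhere). The static_asserts confirm that CreateShaderHelper<Shader>::FuncType is exactly the type of the matching Create* method, so no cast is needed any more. On a C++11 compiler, an alias template (CreateShaderFn, a name made up here) does the same job without the wrapper struct.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-ins for the d3d11.h types (__stdcall omitted).
namespace mock {
typedef long HRESULT;
typedef std::size_t SIZE_T;
struct ClassLinkage {};
struct DeviceChild {};
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

struct Device {
    HRESULT CreateVertexShader(const void*, SIZE_T, ClassLinkage*, VertexShader**) { return 0; }
    HRESULT CreatePixelShader(const void*, SIZE_T, ClassLinkage*, PixelShader**) { return 0; }
};

// C++03-style "template typedef": maps a shader type to its Create* signature.
template <class Shader>
struct CreateShaderHelper {
    typedef HRESULT (Device::*FuncType)(const void*, SIZE_T, ClassLinkage*, Shader**);
};

// C++11 alternative: an alias template with no wrapper struct.
template <class Shader>
using CreateShaderFn = HRESULT (Device::*)(const void*, SIZE_T, ClassLinkage*, Shader**);
} // namespace mock

// Each computed type matches the real method exactly, so no cast is needed.
static_assert(std::is_same<mock::CreateShaderHelper<mock::VertexShader>::FuncType,
                           decltype(&mock::Device::CreateVertexShader)>::value,
              "vertex signature matches");
static_assert(std::is_same<mock::CreateShaderFn<mock::PixelShader>,
                           decltype(&mock::Device::CreatePixelShader)>::value,
              "pixel signature matches");
```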
Now I can rewrite the loadShader method like this:

template <class Shader>
int loadShader(const char* name,
               const char* defines,
               const char* profile,
               typename CreateShaderHelper<Shader>::FuncType shader_creator,
               ConstantBuffer& cbuff,
               Shader*& outShader,
               ID3D11ShaderReflection*& outShaderReflect)
{
    ...
    // No class linkage is used here, so pass NULL for pClassLinkage.
    if ((device->*shader_creator)(shader_buffer->GetBufferPointer(),
                                  shader_buffer->GetBufferSize(),
                                  NULL,
                                  &outShader) == S_OK)
    {
        ...
    }
}
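Here is a self-contained sketch of the templated approach with everything Direct3D-specific replaced by hypothetical mocks: the mock namespace, Device, the shader structs and the simplified loadShader signature (shader compilation, constant buffers and reflection are left out, and a byte array stands in for the compiled blob) are assumptions for illustration. Shader is deduced from the outShader argument, and because FuncType sits in a non-deduced context, passing the method for the wrong shader type is a compile error rather than a silent cast.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical mocks standing in for d3d11.h types; __stdcall omitted.
namespace mock {
typedef long HRESULT;
typedef std::size_t SIZE_T;
const HRESULT S_OK = 0;

struct ClassLinkage {};
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild { std::size_t size = 0; };
struct PixelShader : DeviceChild { std::size_t size = 0; };

struct Device {
    HRESULT CreateVertexShader(const void*, SIZE_T len, ClassLinkage*, VertexShader** out) {
        *out = new VertexShader;  // the mock just records the bytecode length
        (*out)->size = len;
        return S_OK;
    }
    HRESULT CreatePixelShader(const void*, SIZE_T len, ClassLinkage*, PixelShader** out) {
        *out = new PixelShader;
        (*out)->size = len;
        return S_OK;
    }
};

template <class Shader>
struct CreateShaderHelper {
    typedef HRESULT (Device::*FuncType)(const void*, SIZE_T, ClassLinkage*, Shader**);
};

// Simplified, hypothetical loadShader: returns 0 on success, -1 on failure.
template <class Shader>
int loadShader(Device* device,
               const unsigned char* bytecode, std::size_t length,
               typename CreateShaderHelper<Shader>::FuncType shader_creator,
               Shader*& outShader)
{
    // No class linkage, so pass nullptr for the third argument.
    if ((device->*shader_creator)(bytecode, length, nullptr, &outShader) == S_OK)
        return 0;
    return -1;
}

// Usage: Shader is deduced as VertexShader from 'vs'.
inline int demo() {
    Device dev;
    const unsigned char blob[4] = {0, 1, 2, 3};
    VertexShader* vs = nullptr;
    int rc = loadShader(&dev, blob, sizeof blob, &Device::CreateVertexShader, vs);
    // loadShader(&dev, blob, sizeof blob, &Device::CreatePixelShader, vs); // would not compile
    assert(rc == 0 && vs != nullptr && vs->size == 4);
    delete vs;
    return rc;
}
} // namespace mock
```

The design choice is that the type system now does the checking the hack used to bypass: the caller still names the Create* method, but a mismatch between that method and the output shader type no longer compiles.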