Wednesday, September 1, 2010

A horrible hack turned safe thanks to templates.

While writing shader-handling code for DirectX 11, I found myself writing the same code over and over for each shader type (Vertex, Pixel/Fragment, Geometry, Hull, Domain, etc.), and I figured I should make it more generic.

The only problem is that Direct3D 11 has a separate method to create each type of shader, so to make the code generic I first came up with a solution that I would call a "horrible hack".

In Direct3D 11, ID3D11Device has methods such as:

HRESULT __stdcall ID3D11Device::CreateVertexShader(
    [in]   const void *pShaderBytecode,
    [in]   SIZE_T BytecodeLength,
    [in]   ID3D11ClassLinkage *pClassLinkage,
    [out]  ID3D11VertexShader **ppVertexShader);

HRESULT __stdcall ID3D11Device::CreatePixelShader(
    [in]   const void *pShaderBytecode,
    [in]   SIZE_T BytecodeLength,
    [in]   ID3D11ClassLinkage *pClassLinkage,
    [out]  ID3D11PixelShader **ppPixelShader);

Both ID3D11VertexShader and ID3D11PixelShader inherit publicly from ID3D11DeviceChild:

ID3D11VertexShader : public ID3D11DeviceChild { ... }
ID3D11PixelShader : public ID3D11DeviceChild { ... }

The trick was to define a pointer-to-member-function type, CreateShader, with the following signature:

typedef HRESULT (__stdcall ID3D11Device::*CreateShader)(const void*, SIZE_T, ID3D11ClassLinkage*, ID3D11DeviceChild**);

Its only difference from the real methods is the last parameter, which points to the common base class of the shader interfaces.
When creating a shader, I then only have to pass a CreateShader pointing at the right ID3D11Device method (which takes a cast, since the signatures do not match exactly).
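To see why this counts as a hack, here is a minimal, portable sketch. The types Device, DeviceChild, VertexShader and ClassLinkage are hypothetical stand-ins for the Direct3D 11 interfaces (no Windows headers, and __stdcall is dropped), and makeHack is a made-up helper. The real method's pointer type differs from the CreateShader typedef in its last parameter, so no implicit conversion exists between them: the assignment only compiles with a reinterpret_cast, and calling through the cast pointer is undefined behavior in standard C++.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-ins for the Direct3D 11 interfaces (not the real headers).
typedef long HRESULT;
struct ClassLinkage {};
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};

struct Device {
    // Same shape as ID3D11Device::CreateVertexShader, minus __stdcall.
    HRESULT CreateVertexShader(const void*, std::size_t, ClassLinkage*,
                               VertexShader** out) {
        *out = nullptr;  // mock: creates nothing
        return 0;
    }
};

// The "generic" signature from the hack: the last parameter is DeviceChild**.
typedef HRESULT (Device::*CreateShader)(const void*, std::size_t,
                                        ClassLinkage*, DeviceChild**);

// The real method has a different type...
typedef decltype(&Device::CreateVertexShader) RealCreateVertexShader;
static_assert(!std::is_same<RealCreateVertexShader, CreateShader>::value,
              "the signatures differ in the last parameter");
// ...and there is no implicit conversion between the two member pointer types.
static_assert(!std::is_convertible<RealCreateVertexShader, CreateShader>::value,
              "only reinterpret_cast can force the assignment");

// Hypothetical helper: this compiles, but calling through the result is
// undefined behavior, because the pointed-to function has a different type.
inline CreateShader makeHack() {
    return reinterpret_cast<CreateShader>(&Device::CreateVertexShader);
}
```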
The problem is that I then have to pass the shader as a pointer to its base type. An ID3D11VertexShader* converts implicitly to ID3D11DeviceChild*, but an ID3D11VertexShader** does not convert to ID3D11DeviceChild**, and the (wise) compiler complained when I tried to pass it implicitly.
That forced me to declare an intermediate ID3D11DeviceChild*, pass its address to the method, and cast the result back to the ID3D11*Shader type.
In other words, it looked even more horrible; that's where templates come to the rescue.
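The compiler's refusal is justified: if a VertexShader** could bind to a DeviceChild** parameter, the callee could store a pointer to any other DeviceChild (a PixelShader, say) into a VertexShader* variable. The sketch below checks both conversion rules and shows the intermediate-variable workaround. DeviceChild, VertexShader and PixelShader are hypothetical stand-ins, and createVia is a made-up function playing the role of a Create* method that only knows the base type.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for the Direct3D 11 interfaces.
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

// Derived* -> Base* is an implicit conversion...
static_assert(std::is_convertible<VertexShader*, DeviceChild*>::value, "");
// ...but Derived** -> Base** is not: it would open a type-safety hole.
static_assert(!std::is_convertible<VertexShader**, DeviceChild**>::value, "");

// Hypothetical creator that only knows about the base type. It could legally
// store a PixelShader* here, which is why a VertexShader** is rejected.
inline void createVia(DeviceChild** out) {
    static VertexShader instance;
    *out = &instance;
}

// The workaround: go through a temporary base pointer, then cast back.
inline VertexShader* createVertexShaderTheLongWay() {
    DeviceChild* child = nullptr;
    createVia(&child);
    // static_cast trusts the caller; this debug check verifies the dynamic type.
    assert(dynamic_cast<VertexShader*>(child) != nullptr);
    return static_cast<VertexShader*>(child);
}
```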

I decided to turn loadShader into a template and use the shader type to derive the exact signature of the matching ID3D11Device::Create* method.
Since function templates cannot have default template arguments (in C++03), I created a helper structure:

template <class Shader>
struct CreateShaderHelper
{
    // Pointer to the ID3D11Device::Create*Shader method whose output parameter is Shader**
    typedef HRESULT (__stdcall ID3D11Device::*FuncType)(const void*, SIZE_T, ID3D11ClassLinkage*, Shader**);
};
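To check that the helper really yields the exact type of each Create* method, here is a portable sketch with hypothetical stand-in types (Device, VertexShader, PixelShader, ClassLinkage; no __stdcall). The static_asserts compare CreateShaderHelper<Shader>::FuncType with the type of &Device::CreateVertexShader and &Device::CreatePixelShader. It also shows that, from C++11 on, an alias template (CreateShaderFn, a name made up here) can express the same mapping without a wrapper struct.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-ins for the Direct3D 11 types.
typedef long HRESULT;
struct ClassLinkage {};
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

struct Device {
    HRESULT CreateVertexShader(const void*, std::size_t, ClassLinkage*, VertexShader**);
    HRESULT CreatePixelShader(const void*, std::size_t, ClassLinkage*, PixelShader**);
};

// C++03 style: a class template computes the member-function-pointer type.
template <class Shader>
struct CreateShaderHelper {
    typedef HRESULT (Device::*FuncType)(const void*, std::size_t, ClassLinkage*, Shader**);
};

// C++11 style: an alias template names the same type directly.
template <class Shader>
using CreateShaderFn = HRESULT (Device::*)(const void*, std::size_t, ClassLinkage*, Shader**);

// Each instantiation matches exactly one Create* method, with no cast needed.
static_assert(std::is_same<CreateShaderHelper<VertexShader>::FuncType,
                           decltype(&Device::CreateVertexShader)>::value, "");
static_assert(std::is_same<CreateShaderHelper<PixelShader>::FuncType,
                           decltype(&Device::CreatePixelShader)>::value, "");
static_assert(std::is_same<CreateShaderFn<PixelShader>,
                           CreateShaderHelper<PixelShader>::FuncType>::value, "");
```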

Now I can rewrite the loadShader method like this:

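template <class Shader>
int loadShader(const char* name,
        const char* defines,
        const char* profile,
        typename CreateShaderHelper<Shader>::FuncType shader_creator,
        ConstantBuffer& cbuff,
        Shader*& outShader,
        ID3D11ShaderReflection*& outShaderReflect)
{
    // ... obtain the bytecode in shader_buffer (elided) ...
    // Shader is deduced from outShader, so &outShader is exactly Shader**.
    if (SUCCEEDED((device->*shader_creator)(shader_buffer->GetBufferPointer(),
            shader_buffer->GetBufferSize(), NULL /* no class linkage */, &outShader)))
    {
        // ... (elided) ...
    }
    // ... (elided) ...
}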
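Here is a self-contained sketch of the same pattern that compiles without the DirectX SDK. Everything in it is a hypothetical stand-in: Device mimics ID3D11Device, a byte vector replaces the compiled shader blob, and loadShaderSketch keeps only the part of loadShader that matters here. Shader is deduced from the outShader argument (the helper's FuncType sits in a non-deduced context), and the method pointer must then match that type exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the Direct3D 11 types.
typedef long HRESULT;
const HRESULT kOk = 0;
const HRESULT kFail = -1;
struct ClassLinkage {};
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

struct Device {
    HRESULT CreateVertexShader(const void* code, std::size_t size, ClassLinkage*, VertexShader** out) {
        if (!code || size == 0) return kFail;
        *out = new VertexShader();
        return kOk;
    }
    HRESULT CreatePixelShader(const void* code, std::size_t size, ClassLinkage*, PixelShader** out) {
        if (!code || size == 0) return kFail;
        *out = new PixelShader();
        return kOk;
    }
};

template <class Shader>
struct CreateShaderHelper {
    typedef HRESULT (Device::*FuncType)(const void*, std::size_t, ClassLinkage*, Shader**);
};

// One function body for every shader type; the compiler instantiates it per Shader.
template <class Shader>
int loadShaderSketch(Device& device,
                     const std::vector<unsigned char>& bytecode,  // stands in for the blob
                     typename CreateShaderHelper<Shader>::FuncType shader_creator,
                     Shader*& outShader) {
    assert(shader_creator != nullptr);
    outShader = nullptr;
    const void* data = bytecode.empty() ? nullptr : bytecode.data();
    if ((device.*shader_creator)(data, bytecode.size(), nullptr, &outShader) == kOk)
        return 0;
    return -1;
}
```

With VertexShader* vs, a call such as loadShaderSketch(device, blob, &Device::CreateVertexShader, vs) deduces Shader = VertexShader; passing &Device::CreatePixelShader with the same vs fails to compile instead of silently creating the wrong kind of shader.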

3 comments:

  1. I would just overload:
    loadShader( ID3D11VertexShader** ... )
    loadShader( ID3D11PixelShader** ... )

  2. Yeah, but then you need to rewrite a lot of similar code for each shader type.
    With my template function, I only write the code _once_ (in loadShader), and the compiler then instantiates it for each shader type.
    So all my shader-loading code lives in one very generic function, which makes it very maintainable: any bug fix I make in this function applies to all shader types.
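The two comments can also be combined: thin overloads, as suggested in the first comment, can forward to a single template body, so callers get the short overload syntax while the loading logic is still written once. A minimal sketch with hypothetical stand-in types; loadShaderImpl is a made-up, simplified version of loadShader that skips compilation and passes placeholder bytecode.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for the Direct3D 11 types.
typedef long HRESULT;
struct ClassLinkage {};
struct DeviceChild { virtual ~DeviceChild() {} };
struct VertexShader : DeviceChild {};
struct PixelShader : DeviceChild {};

struct Device {
    HRESULT CreateVertexShader(const void*, std::size_t, ClassLinkage*, VertexShader** out) {
        *out = new VertexShader();
        return 0;
    }
    HRESULT CreatePixelShader(const void*, std::size_t, ClassLinkage*, PixelShader** out) {
        *out = new PixelShader();
        return 0;
    }
};

// The single generic implementation (simplified): Shader is deduced
// consistently from both the method pointer and the output parameter.
template <class Shader>
int loadShaderImpl(Device& device,
                   HRESULT (Device::*creator)(const void*, std::size_t, ClassLinkage*, Shader**),
                   Shader*& out) {
    assert(creator != nullptr);
    static const unsigned char bytecode[4] = {0, 0, 0, 0};  // placeholder bytecode
    return (device.*creator)(bytecode, sizeof bytecode, nullptr, &out) == 0 ? 0 : -1;
}

// Thin overloads: each one only picks the matching Create* method.
inline int loadShader(Device& device, VertexShader*& out) {
    return loadShaderImpl(device, &Device::CreateVertexShader, out);
}
inline int loadShader(Device& device, PixelShader*& out) {
    return loadShaderImpl(device, &Device::CreatePixelShader, out);
}
```

Adding a shader type then costs one forwarding overload, while bug fixes still go into the one template body.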