#include "stdafx.h"

// Proxy.cpp: working where Proxy_Adv couldn't!

// Export definitions for supported wrappables

#pragma region XInput9_1_0
typedef DWORD(WINAPI* XInputGetCapabilities_ptr)(DWORD dwUserIndex, DWORD dwFlags, void* pCapabilities);
typedef DWORD(WINAPI* XInputGetDSoundAudioDeviceGuids_ptr)(DWORD dwUserIndex, GUID* pDSoundRenderGuid, GUID* pDSoundCaptureGuid);
typedef DWORD(WINAPI* XInputGetState_ptr)(DWORD dwUserIndex, void* pState);
typedef DWORD(WINAPI* XInputSetState_ptr)(DWORD dwUserIndex, void* pVibration);

XInputGetCapabilities_ptr XInputGetCapabilities_orig;
XInputGetDSoundAudioDeviceGuids_ptr XInputGetDSoundAudioDeviceGuids_orig;
XInputGetState_ptr XInputGetState_orig;
XInputSetState_ptr XInputSetState_orig;

// Exported wrapper: forwards XInputGetCapabilities to the original DLL.
// If Proxy_Attach could not resolve the original export, reports "not connected"
// instead of calling through a null pointer.
EXPORT DWORD WINAPI XInputGetCapabilities(DWORD dwUserIndex, DWORD dwFlags, void* pCapabilities)
{
	if (!XInputGetCapabilities_orig)
		return ERROR_DEVICE_NOT_CONNECTED;

	return XInputGetCapabilities_orig(dwUserIndex, dwFlags, pCapabilities);
}

// Exported wrapper: forwards XInputGetDSoundAudioDeviceGuids to the original DLL.
// If Proxy_Attach could not resolve the original export, reports "not connected"
// instead of calling through a null pointer.
EXPORT DWORD WINAPI XInputGetDSoundAudioDeviceGuids(DWORD dwUserIndex, GUID* pDSoundRenderGuid, GUID* pDSoundCaptureGuid)
{
	if (!XInputGetDSoundAudioDeviceGuids_orig)
		return ERROR_DEVICE_NOT_CONNECTED;

	return XInputGetDSoundAudioDeviceGuids_orig(dwUserIndex, pDSoundRenderGuid, pDSoundCaptureGuid);
}

// Exported wrapper: forwards XInputGetState to the original DLL.
// If Proxy_Attach could not resolve the original export, reports "not connected"
// instead of calling through a null pointer.
EXPORT DWORD WINAPI XInputGetState(DWORD dwUserIndex, void* pState)
{
	if (!XInputGetState_orig)
		return ERROR_DEVICE_NOT_CONNECTED;

	return XInputGetState_orig(dwUserIndex, pState);
}

// Exported wrapper: forwards XInputSetState to the original DLL.
// If Proxy_Attach could not resolve the original export, reports "not connected"
// instead of calling through a null pointer.
EXPORT DWORD WINAPI XInputSetState(DWORD dwUserIndex, void* pVibration)
{
	if (!XInputSetState_orig)
		return ERROR_DEVICE_NOT_CONNECTED;

	return XInputSetState_orig(dwUserIndex, pVibration);
}
#pragma endregion

// Handle of the original (system-directory) DLL; set by Proxy_Attach, null before that.
HMODULE orig_module_{ nullptr };

// Updates main applications import table to point to original DLL functions instead
// Updates main applications import table to point to original DLL functions instead.
//
// Walks the main executable's import descriptors. For every descriptor whose DLL name matches
// LibraryName (case-insensitive), each function imported *by name* has its IAT slot rewritten
// to the same-named export of orig_module_, when that export exists. Imports by ordinal and
// names the original DLL doesn't export are left untouched.
//
// LibraryName: DLL file name to look for in the import table.
// Returns true if at least one matching descriptor was processed; false if orig_module_ is not
// loaded, LibraryName is null, the executable has no import table, or no usable matching
// descriptor was found.
bool Proxy_UpdateImports(const char* LibraryName)
{
	if (!orig_module_ || !LibraryName)
		return false;

	auto* base = (BYTE*)GetModuleHandleA(nullptr);
	auto* importTable = (IMAGE_IMPORT_DESCRIPTOR*)ModuleDirectoryEntryData(base, IMAGE_DIRECTORY_ENTRY_IMPORT);
	if (!importTable)
		return false;

	bool patchedAny = false;

	for (; importTable->Characteristics; ++importTable)
	{
		auto* dllName = (const char*)(base + importTable->Name);
		if (_stricmp(dllName, LibraryName) != 0)
			continue;

		// found the dll

		// After loading, FirstThunk (the IAT) holds resolved addresses, so the import names are
		// only reachable through OriginalFirstThunk. If it is zero (some linkers omit it), we
		// can't tell which slot is which; base + 0 would be the DOS header, not a thunk array.
		if (!importTable->OriginalFirstThunk)
		{
			dlog("Proxy_UpdateImports: no OriginalFirstThunk for %s, skipping", dllName);
			continue;
		}

		auto* thunkData = (IMAGE_THUNK_DATA*)(base + importTable->OriginalFirstThunk);
		auto* iat = (FARPROC*)(base + importTable->FirstThunk);

		for (; thunkData->u1.Ordinal; ++thunkData, ++iat)
		{
			// Ordinal imports carry no name to look up in the original DLL; leave them as-is.
			if (IMAGE_SNAP_BY_ORDINAL(thunkData->u1.Ordinal))
				continue;

			auto* importInfo = (IMAGE_IMPORT_BY_NAME*)(base + thunkData->u1.AddressOfData);

			// Check orig/system DLL for this proc, if found then replace IAT with the orig DLL's proc address
			dlog("%s IAT@0x%p (-> 0x%p)", (char*)importInfo->Name, iat, *iat);

			auto origProc = GetProcAddress(orig_module_, (char*)importInfo->Name);
			if (origProc)
			{
				SafeWrite(iat, origProc);
				dlog("-> 0x%p", *iat);
			}
		}

		patchedAny = true;
	}

	if (patchedAny)
		dlog("Proxy_UpdateImports succeeded");

	return patchedAny;
}

// Inits proxy by loading original DLL, then calls UpdateImports & UpdateExports
// Inits proxy by loading the original DLL, resolving its exports, then calling UpdateImports.
//
// The original DLL is the file with the same name as this proxy (DllModule) inside the system
// directory. Its XInput exports are stored in the *_orig pointers used by our wrappers, and the
// main executable's imports of this DLL are then redirected straight to the original.
//
// DllModule: handle of this proxy DLL, used to obtain its own file name.
// Returns true once initialised (immediately on later calls). Returns false if the system
// directory or our own module path can't be determined, or the original DLL fails to load.
// Missing exports and a failed import redirection are logged but not fatal.
bool Proxy_Attach(HMODULE DllModule)
{
	// static so the guard survives across calls and the original DLL is loaded only once
	static bool inited = false;
	if (inited)
		return true;

	// Get path of system folder
	WCHAR systemPathW[MAX_PATH] = { 0 };
	// we could use 4096 instead of MAX_PATH, but that makes this function exceed stack size
	// not likely system folder path will be larger than 260 chars anyway

	// 0 means failure; a value >= the buffer size is the required size (buffer too small)
	const UINT systemLen = GetSystemDirectoryW(systemPathW, _countof(systemPathW));
	if (systemLen == 0 || systemLen >= _countof(systemPathW))
	{
		dlog("Proxy_Attach: failed to get system directory!");
		return false;
	}

	// Find our DLLs path
	WCHAR modulePathW[4096] = { 0 };
	// 0 means failure; a return equal to the buffer size means the path was truncated
	const DWORD moduleLen = GetModuleFileNameW(DllModule, modulePathW, _countof(modulePathW));
	if (moduleLen == 0 || moduleLen >= _countof(modulePathW))
	{
		dlog("Proxy_Attach: failed to get module path!");
		return false;
	}

	std::filesystem::path systemPath = systemPathW;
	std::filesystem::path modulePath = modulePathW;

	auto origDllPath = systemPath / modulePath.filename();

	orig_module_ = LoadLibraryW(origDllPath.wstring().c_str());
	if (!orig_module_)
	{
		dlog("Proxy_Attach: failed to load original module %s", origDllPath.string().c_str());
		return false;
	}

	// Store pointers for orig library funcs, in case our wrapper versions get used for some reason (eg. GetProcAddress, which it sadly seems we can't redirect...)
	// Done before the import redirection so the wrappers can forward even when the executable
	// doesn't import this DLL directly (Proxy_UpdateImports returns false in that case).
	XInputGetCapabilities_orig = (XInputGetCapabilities_ptr)GetProcAddress(orig_module_, "XInputGetCapabilities");
	XInputGetDSoundAudioDeviceGuids_orig = (XInputGetDSoundAudioDeviceGuids_ptr)GetProcAddress(orig_module_, "XInputGetDSoundAudioDeviceGuids");
	XInputGetState_orig = (XInputGetState_ptr)GetProcAddress(orig_module_, "XInputGetState");
	XInputSetState_orig = (XInputSetState_ptr)GetProcAddress(orig_module_, "XInputSetState");

	if (!XInputGetCapabilities_orig)
		dlog("Failed to get XInputGetCapabilities address!");
	if (!XInputGetDSoundAudioDeviceGuids_orig)
		dlog("Failed to get XInputGetDSoundAudioDeviceGuids address!");
	if (!XInputGetState_orig)
		dlog("Failed to get XInputGetState address!");
	if (!XInputSetState_orig)
		dlog("Failed to get XInputSetState address!");

	// Redirect imports to the orig library, removes a small layer of indirection.
	// Only an optimisation: the wrappers above still forward without it, so failure isn't fatal.
	if (!Proxy_UpdateImports(modulePath.filename().string().c_str()))
		dlog("Proxy_UpdateImports FAILED!");

	dlog("Inited = true");

	inited = true;

	return true;
}
