#include "stdafx.h"
#include <fstream>

void* ModuleDirectoryEntryData(void* Module, int DirectoryEntry, int* EntrySize)
{
	auto* base = (BYTE*)Module;
	auto* dosHeader = (IMAGE_DOS_HEADER*)base;
	if (dosHeader->e_magic != IMAGE_DOS_SIGNATURE)
		return nullptr; // invalid header :(

	auto* ntHeader = (IMAGE_NT_HEADERS*)(base + dosHeader->e_lfanew);
	if (ntHeader->Signature != IMAGE_NT_SIGNATURE)
		return nullptr; // invalid header :(

	auto entryAddr = ntHeader->OptionalHeader.DataDirectory[DirectoryEntry].VirtualAddress;
	if (!entryAddr)
		return nullptr;

	if (EntrySize)
		*EntrySize = ntHeader->OptionalHeader.DataDirectory[DirectoryEntry].Size;

	return base + entryAddr;
}

bool FileExists(std::filesystem::path File)
{
	return GetFileAttributesW(File.wstring().c_str()) != -1;
}

// Finds the import address table (IAT) slot through which `Module` calls
// `ImportName` from `LibraryName`, so the slot can be read or patched.
// LibraryName is compared case-insensitively with the import descriptor's DLL name.
// ImportName is either an import name (compared case-insensitively) or an ordinal
// passed as a pointer value <= 0xFFFF (MAKEINTRESOURCE-style), matched against
// ordinal imports. Returns nullptr if any argument is null, the module has no
// import table, or no matching import is found.
FARPROC* GetIATPointer(void* Module, const char* LibraryName, const char* ImportName)
{
	if (!Module || !LibraryName || !ImportName)
		return nullptr;

	auto* base = (BYTE*)Module;
	auto* importTable = (IMAGE_IMPORT_DESCRIPTOR*)ModuleDirectoryEntryData(Module, IMAGE_DIRECTORY_ENTRY_IMPORT);
	if (!importTable)
		return nullptr;

	const bool byName = (ULONG_PTR)ImportName > 0xFFFF;
	const ULONG_PTR ordinal = (ULONG_PTR)ImportName & 0xFFFF;

	// The descriptor array ends with an all-zero entry. Test Name rather than
	// Characteristics (an alias of OriginalFirstThunk), which can legitimately be 0.
	for (; importTable->Name; ++importTable)
	{
		auto* dllName = (const char*)(base + importTable->Name);
		if (_stricmp(dllName, LibraryName) != 0)
			continue;

		// After loading, FirstThunk holds resolved addresses, so names/ordinals can only
		// be read from the lookup table (OriginalFirstThunk). Some linkers omit it.
		if (!importTable->OriginalFirstThunk)
			continue;

		auto* thunkData = (IMAGE_THUNK_DATA*)(base + importTable->OriginalFirstThunk);
		auto* iat = (FARPROC*)(base + importTable->FirstThunk);

		for (; thunkData->u1.Ordinal; ++thunkData, ++iat)
		{
			if (IMAGE_SNAP_BY_ORDINAL(thunkData->u1.Ordinal))
			{
				if (!byName && IMAGE_ORDINAL(thunkData->u1.Ordinal) == ordinal)
					return iat; // found the import (by ordinal)
			}
			else if (byName)
			{
				auto* importInfo = (IMAGE_IMPORT_BY_NAME*)(base + thunkData->u1.AddressOfData);
				if (!_stricmp((const char*)importInfo->Name, ImportName))
					return iat; // found the import (by name)
			}
		}

		// The same DLL may be referenced by more than one descriptor; keep looking.
	}

	return nullptr;
}

FARPROC* GetExpPointer(void* Module, const char* ProcName)
{
	// TODO: handle forwarded exports? (eg. NTDLL.RtlDecodePointer inside kernelbase.dll)
	// see https://github.com/Speedi13/Custom-GetProcAddress-and-GetModuleHandle-and-more/blob/master/CustomWinApi.cpp
	// might not be necessary though, maybe caller should already know if it's forwarded or not?

	auto* base = (BYTE*)Module;
	auto* exportTable = (IMAGE_EXPORT_DIRECTORY*)ModuleDirectoryEntryData(Module, IMAGE_DIRECTORY_ENTRY_EXPORT);
	if (!exportTable)
		return nullptr;

	auto* functions = (FARPROC*)(base + exportTable->AddressOfFunctions);
	auto* nameRvas = (DWORD*)(base + exportTable->AddressOfNames);
	auto* ordinals = (WORD*)(base + exportTable->AddressOfNameOrdinals);

	bool searchName = (size_t)ProcName > 0xFFFF; // should we search for it as name or ordinal?

	for (DWORD i = 0; i < exportTable->NumberOfFunctions; i++)
	{
		auto* func = &functions[i];

		if (searchName)
		{
			auto* wtf = &nameRvas[i];
			auto nameRva = nameRvas[i];
			if (nameRva)
			{
				char* name = (char*)(base + nameRva);
				if (!_stricmp(ProcName, name))
					return func;
			}
		}
		else
		{
			WORD ordinal = ordinals[i];
			if (ordinal == ((ULONG_PTR)ProcName & 0xFFFF))
				return func;
		}
	}

	return nullptr;
}

void RedirectProcAddress(const char* ModuleName, const char* ProcName, FARPROC NewProc, FARPROC* OldProc)
{
	if (OldProc)
		*OldProc = 0;

	auto moduleHandle = GetModuleHandleA(ModuleName);

	auto* exportAddress = (FARPROC*)GetExpPointer(moduleHandle, ProcName);
	auto* importAddress = (FARPROC*)GetIATPointer(GetModuleHandleA(NULL), ModuleName, ProcName);

	if (exportAddress)
	{
		if (OldProc)
			*OldProc = *exportAddress;
		SafeWrite(exportAddress, NewProc);
	}

	if (importAddress)
	{
		if (OldProc)
			*OldProc = *importAddress;
		SafeWrite(importAddress, NewProc);
	}
}

extern std::filesystem::path dll_path_;

// Debug-only logger. Formats a printf-style message (truncated to 4095 characters)
// and appends it, followed by a newline, to "inject.log" in the parent directory of
// dll_path_. Compiles to an empty function unless _DEBUG is defined.
void dlog(const char* Format, ...)
{
#ifdef _DEBUG
	// Stack buffer: the previous `new char[4096]` was never freed, so every call leaked.
	char str[4096];

	va_list ap;
	va_start(ap, Format);
	int written = vsnprintf(str, sizeof(str), Format, ap);
	va_end(ap);

	if (written < 0)
		return; // formatting failed; buffer contents are unspecified

	std::ofstream file(dll_path_.parent_path() / "inject.log", std::ofstream::out | std::ofstream::app);
	if (!file.is_open())
		return; // logging is best-effort; drop the message

	file << str << "\n";
	// file is flushed and closed by its destructor
#endif
}
