diff --git a/src/core/cpu_patches.cpp b/src/core/cpu_patches.cpp index 1b159d32b..8850d4f99 100644 --- a/src/core/cpu_patches.cpp +++ b/src/core/cpu_patches.cpp @@ -11,7 +11,6 @@ #include "common/assert.h" #include "common/types.h" #include "core/signals.h" -#include "core/tls.h" #include "cpu_patches.h" #ifdef _WIN32 @@ -538,52 +537,6 @@ static bool FilterRosetta2Only(const ZydisDecodedOperand*) { return ret; } -#else // __APPLE__ - -static bool FilterTcbAccess(const ZydisDecodedOperand* operands) { - const auto& dst_op = operands[0]; - const auto& src_op = operands[1]; - - // Patch only 'mov (64-bit register), fs:[0]' - return src_op.type == ZYDIS_OPERAND_TYPE_MEMORY && src_op.mem.segment == ZYDIS_REGISTER_FS && - src_op.mem.base == ZYDIS_REGISTER_NONE && src_op.mem.index == ZYDIS_REGISTER_NONE && - src_op.mem.disp.value == 0 && dst_op.reg.value >= ZYDIS_REGISTER_RAX && - dst_op.reg.value <= ZYDIS_REGISTER_R15; -} - -static void GenerateTcbAccess(const ZydisDecodedOperand* operands, Xbyak::CodeGenerator& c) { - const auto dst = ZydisToXbyakRegisterOperand(operands[0]); - -#if defined(_WIN32) - // The following logic is based on the Kernel32.dll asm of TlsGetValue - static constexpr u32 TlsSlotsOffset = 0x1480; - static constexpr u32 TlsExpansionSlotsOffset = 0x1780; - static constexpr u32 TlsMinimumAvailable = 64; - - const auto slot = GetTcbKey(); - - // Load the pointer to the table of TLS slots. - c.putSeg(gs); - if (slot < TlsMinimumAvailable) { - // Load the pointer to TLS slots. - c.mov(dst, ptr[reinterpret_cast(TlsSlotsOffset + slot * sizeof(LPVOID))]); - } else { - const u32 tls_index = slot - TlsMinimumAvailable; - - // Load the pointer to the table of TLS expansion slots. - c.mov(dst, ptr[reinterpret_cast(TlsExpansionSlotsOffset)]); - // Load the pointer to our buffer. - c.mov(dst, qword[dst + tls_index * sizeof(LPVOID)]); - } -#else - const auto src = ZydisToXbyakMemoryOperand(operands[1]); - - // Replace fs read with gs read. - c.putSeg(gs); - c.mov(dst, src); -#endif -} - #endif // __APPLE__ using PatchFilter = bool (*)(const ZydisDecodedOperand*); @@ -600,13 +553,6 @@ struct PatchInfo { }; static const std::unordered_map Patches = { -#if defined(_WIN32) - // Windows needs a trampoline. - {ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, true}}, -#elif !defined(__APPLE__) - {ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, false}}, -#endif - #ifdef __APPLE__ // Patches for instruction sets not supported by Rosetta 2. // BMI1 @@ -658,15 +604,26 @@ static PatchModule* GetModule(const void* ptr) { return &(std::prev(upper_bound)->second); } -/// Returns a boolean indicating whether the instruction was patched, and the offset to advance past -/// whatever is at the current code pointer. -static std::pair TryPatch(u8* code, PatchModule* module) { +static bool TryPatch(void* code_address) { + auto* code = static_cast(code_address); + auto* module = GetModule(code); + if (module == nullptr) { + return false; + } + + std::unique_lock lock{module->mutex}; + + // Return early if already patched, in case multiple threads signaled at the same time. + if (std::ranges::find(module->patched, code) != module->patched.end()) { + return true; + } + ZydisDecodedInstruction instruction; ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT]; const auto status = ZydisDecoderDecodeFull(&instr_decoder, code, module->end - code, &instruction, operands); if (!ZYAN_SUCCESS(status)) { - return std::make_pair(false, 1); + return false; } if (Patches.contains(instruction.mnemonic)) { @@ -706,52 +663,20 @@ static std::pair TryPatch(u8* code, PatchModule* module) { module->patched.insert(code); LOG_DEBUG(Core, "Patched instruction '{}' at: {}", ZydisMnemonicGetString(instruction.mnemonic), fmt::ptr(code)); - return std::make_pair(true, instruction.length); + return true; } } } - return std::make_pair(false, instruction.length); -} - -static bool TryPatchJit(void* code_address) { - auto* code = static_cast(code_address); - auto* module = GetModule(code); - if (module == nullptr) { - return false; - } - - std::unique_lock lock{module->mutex}; - - // Return early if already patched, in case multiple threads signaled at the same time. - if (std::ranges::find(module->patched, code) != module->patched.end()) { - return true; - } - - return TryPatch(code, module).first; -} - -static void TryPatchAot(void* code_address, u64 code_size) { - auto* code = static_cast(code_address); - auto* module = GetModule(code); - if (module == nullptr) { - return; - } - - std::unique_lock lock{module->mutex}; - - const auto* end = code + code_size; - while (code < end) { - code += TryPatch(code, module).second; - } + return false; } static bool PatchesAccessViolationHandler(void* code_address, void* fault_address, bool is_write) { - return TryPatchJit(code_address); + return TryPatch(code_address); } static bool PatchesIllegalInstructionHandler(void* code_address) { - return TryPatchJit(code_address); + return TryPatch(code_address); } static void PatchesInit() { @@ -761,7 +686,6 @@ static void PatchesInit() { auto* signals = Signals::Instance(); // Should be called last. constexpr auto priority = std::numeric_limits::max(); - signals->RegisterAccessViolationHandler(PatchesAccessViolationHandler, priority); signals->RegisterIllegalInstructionHandler(PatchesIllegalInstructionHandler, priority); } } @@ -785,15 +709,51 @@ void PrePatchInstructions(u64 segment_addr, u64 segment_size) { auto* code_page = reinterpret_cast(Common::AlignUp(segment_addr, 0x1000)); const auto* end_page = code_page + Common::AlignUp(segment_size, 0x1000); while (code_page < end_page) { - TryPatchJit(code_page); + TryPatch(code_page); code_page += 0x1000; } } #elif !defined(_WIN32) - // Linux and others have an FS segment pointing to valid memory, so continue to do full - // ahead-of-time patching for now until a better solution is worked out. - if (!Patches.empty()) { - TryPatchAot(reinterpret_cast(segment_addr), segment_size); + // On Linux and similar, we need to patch FS accesses ahead of time since FS points to valid + // memory. + // TODO: Replace this with swapping in/out the correct FS around guest code. + auto* code = reinterpret_cast(segment_addr); + const auto* end = code + segment_size; + while (code < end) { + ZydisDecodedInstruction instruction; + ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT]; + const auto status = + ZydisDecoderDecodeFull(&instr_decoder, code, end - code, &instruction, operands); + if (!ZYAN_SUCCESS(status)) { + code += 1; + continue; + } + + if (instruction.mnemonic == ZYDIS_MNEMONIC_MOV) { + const auto& dst_op = operands[0]; + const auto& src_op = operands[1]; + + // Patch only 'mov (64-bit register), fs:[0]' + if (src_op.type == ZYDIS_OPERAND_TYPE_MEMORY && + src_op.mem.segment == ZYDIS_REGISTER_FS && src_op.mem.base == ZYDIS_REGISTER_NONE && + src_op.mem.index == ZYDIS_REGISTER_NONE && src_op.mem.disp.value == 0 && + dst_op.reg.value >= ZYDIS_REGISTER_RAX && dst_op.reg.value <= ZYDIS_REGISTER_R15) { + const auto dst = ZydisToXbyakRegisterOperand(operands[0]); + const auto src = ZydisToXbyakMemoryOperand(operands[1]); + + // Replace fs read with gs read. + Xbyak::CodeGenerator gen(instruction.length, code); + gen.putSeg(gs); + gen.mov(dst, src); + + const u64 remaining = instruction.length - (gen.getCurr() - code); + if (remaining > 0) { + gen.nop(instruction.length - (gen.getCurr() - code)); + } + } + } + + code += instruction.length; } #endif } diff --git a/src/core/tls.cpp b/src/core/tls.cpp index eb07e7a72..333d40a65 100644 --- a/src/core/tls.cpp +++ b/src/core/tls.cpp @@ -5,9 +5,12 @@ #include "common/arch.h" #include "common/assert.h" #include "common/types.h" +#include "core/signals.h" #include "core/tls.h" #ifdef _WIN32 +#include +#include #include #elif defined(__APPLE__) && defined(ARCH_X86_64) #include @@ -25,14 +28,43 @@ namespace Core { // Windows static DWORD slot = 0; +static ZydisDecoder tls_instr_decoder; static std::once_flag slot_alloc_flag; -static void AllocTcbKey() { - slot = TlsAlloc(); +static bool TlsAccessViolationHandler(void* code_address, void* fault_address, bool is_write) { + ZydisDecodedInstruction instruction; + ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT]; + const auto status = + ZydisDecoderDecodeFull(&tls_instr_decoder, code_address, 0x20, &instruction, operands); + if (!ZYAN_SUCCESS(status)) { + return false; + } + + for (u32 i = 0; i < instruction.operand_count_visible; i++) { + if (operands[i].type == ZYDIS_OPERAND_TYPE_MEMORY && + operands[i].mem.segment == ZYDIS_REGISTER_FS) { + // Set the FS register and try again. + const auto* tcb_base = GetTcbBase(); + asm volatile("wrfsbase %0" ::"r"(tcb_base) : "memory"); + return true; + } + } + + return false; } -u32 GetTcbKey() { - std::call_once(slot_alloc_flag, &AllocTcbKey); +static void InitializeTls() { + slot = TlsAlloc(); + ZydisDecoderInit(&tls_instr_decoder, ZYDIS_MACHINE_MODE_LONG_64, ZYDIS_STACK_WIDTH_64); + + auto* signals = Signals::Instance(); + // Should be called last. + constexpr auto priority = std::numeric_limits::max(); + signals->RegisterAccessViolationHandler(TlsAccessViolationHandler, priority); +} + +static u32 GetTcbKey() { + std::call_once(slot_alloc_flag, &InitializeTls); return slot; } diff --git a/src/core/tls.h b/src/core/tls.h index f5bf33184..e5ef7d714 100644 --- a/src/core/tls.h +++ b/src/core/tls.h @@ -22,11 +22,6 @@ struct Tcb { void* tcb_thread; }; -#ifdef _WIN32 -/// Gets the thread local storage key for the TCB block. -u32 GetTcbKey(); -#endif - /// Sets the data pointer to the TCB block. void SetTcbBase(void* image_address);