core: Replace Windows TLS code patches with exception handler.

This commit is contained in:
squidbus 2024-09-16 22:08:15 -07:00
parent 28ec489dbe
commit c70014a992
3 changed files with 96 additions and 109 deletions

View File

@ -11,7 +11,6 @@
#include "common/assert.h"
#include "common/types.h"
#include "core/signals.h"
#include "core/tls.h"
#include "cpu_patches.h"
#ifdef _WIN32
@ -538,52 +537,6 @@ static bool FilterRosetta2Only(const ZydisDecodedOperand*) {
return ret;
}
#else // __APPLE__
static bool FilterTcbAccess(const ZydisDecodedOperand* operands) {
const auto& dst_op = operands[0];
const auto& src_op = operands[1];
// Patch only 'mov (64-bit register), fs:[0]'
return src_op.type == ZYDIS_OPERAND_TYPE_MEMORY && src_op.mem.segment == ZYDIS_REGISTER_FS &&
src_op.mem.base == ZYDIS_REGISTER_NONE && src_op.mem.index == ZYDIS_REGISTER_NONE &&
src_op.mem.disp.value == 0 && dst_op.reg.value >= ZYDIS_REGISTER_RAX &&
dst_op.reg.value <= ZYDIS_REGISTER_R15;
}
static void GenerateTcbAccess(const ZydisDecodedOperand* operands, Xbyak::CodeGenerator& c) {
const auto dst = ZydisToXbyakRegisterOperand(operands[0]);
#if defined(_WIN32)
// The following logic is based on the Kernel32.dll asm of TlsGetValue
static constexpr u32 TlsSlotsOffset = 0x1480;
static constexpr u32 TlsExpansionSlotsOffset = 0x1780;
static constexpr u32 TlsMinimumAvailable = 64;
const auto slot = GetTcbKey();
// Load the pointer to the table of TLS slots.
c.putSeg(gs);
if (slot < TlsMinimumAvailable) {
// Load the pointer to TLS slots.
c.mov(dst, ptr[reinterpret_cast<void*>(TlsSlotsOffset + slot * sizeof(LPVOID))]);
} else {
const u32 tls_index = slot - TlsMinimumAvailable;
// Load the pointer to the table of TLS expansion slots.
c.mov(dst, ptr[reinterpret_cast<void*>(TlsExpansionSlotsOffset)]);
// Load the pointer to our buffer.
c.mov(dst, qword[dst + tls_index * sizeof(LPVOID)]);
}
#else
const auto src = ZydisToXbyakMemoryOperand(operands[1]);
// Replace fs read with gs read.
c.putSeg(gs);
c.mov(dst, src);
#endif
}
#endif // __APPLE__
using PatchFilter = bool (*)(const ZydisDecodedOperand*);
@ -600,13 +553,6 @@ struct PatchInfo {
};
static const std::unordered_map<ZydisMnemonic, PatchInfo> Patches = {
#if defined(_WIN32)
// Windows needs a trampoline.
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, true}},
#elif !defined(__APPLE__)
{ZYDIS_MNEMONIC_MOV, {FilterTcbAccess, GenerateTcbAccess, false}},
#endif
#ifdef __APPLE__
// Patches for instruction sets not supported by Rosetta 2.
// BMI1
@ -658,15 +604,26 @@ static PatchModule* GetModule(const void* ptr) {
return &(std::prev(upper_bound)->second);
}
/// Returns a boolean indicating whether the instruction was patched, and the offset to advance past
/// whatever is at the current code pointer.
static std::pair<bool, u64> TryPatch(u8* code, PatchModule* module) {
static bool TryPatch(void* code_address) {
auto* code = static_cast<u8*>(code_address);
auto* module = GetModule(code);
if (module == nullptr) {
return false;
}
std::unique_lock lock{module->mutex};
// Return early if already patched, in case multiple threads signaled at the same time.
if (std::ranges::find(module->patched, code) != module->patched.end()) {
return true;
}
ZydisDecodedInstruction instruction;
ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT];
const auto status =
ZydisDecoderDecodeFull(&instr_decoder, code, module->end - code, &instruction, operands);
if (!ZYAN_SUCCESS(status)) {
return std::make_pair(false, 1);
return false;
}
if (Patches.contains(instruction.mnemonic)) {
@ -706,52 +663,20 @@ static std::pair<bool, u64> TryPatch(u8* code, PatchModule* module) {
module->patched.insert(code);
LOG_DEBUG(Core, "Patched instruction '{}' at: {}",
ZydisMnemonicGetString(instruction.mnemonic), fmt::ptr(code));
return std::make_pair(true, instruction.length);
return true;
}
}
}
return std::make_pair(false, instruction.length);
}
static bool TryPatchJit(void* code_address) {
auto* code = static_cast<u8*>(code_address);
auto* module = GetModule(code);
if (module == nullptr) {
return false;
}
std::unique_lock lock{module->mutex};
// Return early if already patched, in case multiple threads signaled at the same time.
if (std::ranges::find(module->patched, code) != module->patched.end()) {
return true;
}
return TryPatch(code, module).first;
}
static void TryPatchAot(void* code_address, u64 code_size) {
auto* code = static_cast<u8*>(code_address);
auto* module = GetModule(code);
if (module == nullptr) {
return;
}
std::unique_lock lock{module->mutex};
const auto* end = code + code_size;
while (code < end) {
code += TryPatch(code, module).second;
}
}
static bool PatchesAccessViolationHandler(void* code_address, void* fault_address, bool is_write) {
return TryPatchJit(code_address);
return TryPatch(code_address);
}
static bool PatchesIllegalInstructionHandler(void* code_address) {
return TryPatchJit(code_address);
return TryPatch(code_address);
}
static void PatchesInit() {
@ -761,7 +686,6 @@ static void PatchesInit() {
auto* signals = Signals::Instance();
// Should be called last.
constexpr auto priority = std::numeric_limits<u32>::max();
signals->RegisterAccessViolationHandler(PatchesAccessViolationHandler, priority);
signals->RegisterIllegalInstructionHandler(PatchesIllegalInstructionHandler, priority);
}
}
@ -785,15 +709,51 @@ void PrePatchInstructions(u64 segment_addr, u64 segment_size) {
auto* code_page = reinterpret_cast<u8*>(Common::AlignUp(segment_addr, 0x1000));
const auto* end_page = code_page + Common::AlignUp(segment_size, 0x1000);
while (code_page < end_page) {
TryPatchJit(code_page);
TryPatch(code_page);
code_page += 0x1000;
}
}
#elif !defined(_WIN32)
// Linux and others have an FS segment pointing to valid memory, so continue to do full
// ahead-of-time patching for now until a better solution is worked out.
if (!Patches.empty()) {
TryPatchAot(reinterpret_cast<void*>(segment_addr), segment_size);
// On Linux and similar, we need to patch FS accesses ahead of time since FS points to valid
// memory.
// TODO: Replace this with swapping in/out the correct FS around guest code.
auto* code = reinterpret_cast<u8*>(segment_addr);
const auto* end = code + segment_size;
while (code < end) {
ZydisDecodedInstruction instruction;
ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT];
const auto status =
ZydisDecoderDecodeFull(&instr_decoder, code, end - code, &instruction, operands);
if (!ZYAN_SUCCESS(status)) {
code += 1;
continue;
}
if (instruction.mnemonic == ZYDIS_MNEMONIC_MOV) {
const auto& dst_op = operands[0];
const auto& src_op = operands[1];
// Patch only 'mov (64-bit register), fs:[0]'
if (src_op.type == ZYDIS_OPERAND_TYPE_MEMORY &&
src_op.mem.segment == ZYDIS_REGISTER_FS && src_op.mem.base == ZYDIS_REGISTER_NONE &&
src_op.mem.index == ZYDIS_REGISTER_NONE && src_op.mem.disp.value == 0 &&
dst_op.reg.value >= ZYDIS_REGISTER_RAX && dst_op.reg.value <= ZYDIS_REGISTER_R15) {
const auto dst = ZydisToXbyakRegisterOperand(operands[0]);
const auto src = ZydisToXbyakMemoryOperand(operands[1]);
// Replace fs read with gs read.
Xbyak::CodeGenerator gen(instruction.length, code);
gen.putSeg(gs);
gen.mov(dst, src);
const u64 remaining = instruction.length - (gen.getCurr() - code);
if (remaining > 0) {
gen.nop(instruction.length - (gen.getCurr() - code));
}
}
}
code += instruction.length;
}
#endif
}

View File

@ -5,9 +5,12 @@
#include "common/arch.h"
#include "common/assert.h"
#include "common/types.h"
#include "core/signals.h"
#include "core/tls.h"
#ifdef _WIN32
#include <Zydis/Zydis.h>
#include <immintrin.h>
#include <windows.h>
#elif defined(__APPLE__) && defined(ARCH_X86_64)
#include <architecture/i386/table.h>
@ -25,14 +28,43 @@ namespace Core {
// Windows
static DWORD slot = 0;
static ZydisDecoder tls_instr_decoder;
static std::once_flag slot_alloc_flag;
static void AllocTcbKey() {
slot = TlsAlloc();
static bool TlsAccessViolationHandler(void* code_address, void* fault_address, bool is_write) {
ZydisDecodedInstruction instruction;
ZydisDecodedOperand operands[ZYDIS_MAX_OPERAND_COUNT];
const auto status =
ZydisDecoderDecodeFull(&tls_instr_decoder, code_address, 0x20, &instruction, operands);
if (!ZYAN_SUCCESS(status)) {
return false;
}
u32 GetTcbKey() {
std::call_once(slot_alloc_flag, &AllocTcbKey);
for (u32 i = 0; i < instruction.operand_count_visible; i++) {
if (operands[i].type == ZYDIS_OPERAND_TYPE_MEMORY &&
operands[i].mem.segment == ZYDIS_REGISTER_FS) {
// Set the FS register and try again.
const auto* tcb_base = GetTcbBase();
asm volatile("wrfsbase %0" ::"r"(tcb_base) : "memory");
return true;
}
}
return false;
}
static void InitializeTls() {
slot = TlsAlloc();
ZydisDecoderInit(&tls_instr_decoder, ZYDIS_MACHINE_MODE_LONG_64, ZYDIS_STACK_WIDTH_64);
auto* signals = Signals::Instance();
// Should be called last.
constexpr auto priority = std::numeric_limits<u32>::max();
signals->RegisterAccessViolationHandler(TlsAccessViolationHandler, priority);
}
static u32 GetTcbKey() {
std::call_once(slot_alloc_flag, &InitializeTls);
return slot;
}

View File

@ -22,11 +22,6 @@ struct Tcb {
void* tcb_thread;
};
#ifdef _WIN32
/// Gets the thread local storage key for the TCB block.
u32 GetTcbKey();
#endif
/// Sets the data pointer to the TCB block.
void SetTcbBase(void* image_address);