semaphore: Use separate signaled flag to prevent races

This commit is contained in:
squidbus 2024-11-28 19:27:29 -08:00
parent e0675785bd
commit 856564759b

View File

@ -72,6 +72,7 @@ public:
} }
it = wait_list.erase(it); it = wait_list.erase(it);
token_count -= waiter->need_count; token_count -= waiter->need_count;
waiter->was_signaled = true;
waiter->cv.notify_one(); waiter->cv.notify_one();
} }
@ -106,6 +107,7 @@ public:
std::condition_variable cv; std::condition_variable cv;
u32 priority; u32 priority;
s32 need_count; s32 need_count;
bool was_signaled{};
bool was_deleted{}; bool was_deleted{};
bool was_cancled{}; bool was_cancled{};
@ -137,16 +139,17 @@ public:
} }
// Wait until timeout runs out, recording how much remaining time there was. // Wait until timeout runs out, recording how much remaining time there was.
const auto start = std::chrono::high_resolution_clock::now(); const auto start = std::chrono::high_resolution_clock::now();
const auto status = cv.wait_for(lk, std::chrono::microseconds(*timeout)); const auto signaled = cv.wait_for(lk, std::chrono::microseconds(*timeout),
[this] { return was_signaled; });
const auto end = std::chrono::high_resolution_clock::now(); const auto end = std::chrono::high_resolution_clock::now();
const auto time = const auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count(); std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
if (status == std::cv_status::timeout) { if (signaled) {
*timeout = 0;
} else {
*timeout -= time; *timeout -= time;
} else {
*timeout = 0;
} }
return GetResult(status == std::cv_status::timeout); return GetResult(!signaled);
} }
}; };