1
0
Fork 0
mirror of git://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git synced 2025-06-08 00:07:34 +09:00
linux/rust/pin-init/examples/mutex.rs
Benno Lossin 5c4167b405 rust: pin-init: examples: conditionally enable feature(lint_reasons)
`lint_reasons` is unstable in Rust 1.80 and earlier, enable it
conditionally in the examples to allow compiling them with older
compilers.

Link: ec494fe686
Link: https://lore.kernel.org/all/20250414195928.129040-3-benno.lossin@proton.me
Signed-off-by: Benno Lossin <benno.lossin@proton.me>
2025-04-21 23:19:12 +02:00

210 lines
5.3 KiB
Rust

// SPDX-License-Identifier: Apache-2.0 OR MIT
#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
#![allow(clippy::missing_safety_doc)]
use core::{
cell::{Cell, UnsafeCell},
marker::PhantomPinned,
ops::{Deref, DerefMut},
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
};
use std::{
sync::Arc,
thread::{self, park, sleep, Builder, Thread},
time::Duration,
};
use pin_init::*;
#[expect(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;
pub struct SpinLock {
inner: AtomicBool,
}
impl SpinLock {
#[inline]
pub fn acquire(&self) -> SpinLockGuard<'_> {
while self
.inner
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
while self.inner.load(Ordering::Relaxed) {
thread::yield_now();
}
}
SpinLockGuard(self)
}
#[inline]
#[allow(clippy::new_without_default)]
pub const fn new() -> Self {
Self {
inner: AtomicBool::new(false),
}
}
}
pub struct SpinLockGuard<'a>(&'a SpinLock);
impl Drop for SpinLockGuard<'_> {
#[inline]
fn drop(&mut self) {
self.0.inner.store(false, Ordering::Release);
}
}
#[pin_data]
pub struct CMutex<T> {
#[pin]
wait_list: ListHead,
spin_lock: SpinLock,
locked: Cell<bool>,
#[pin]
data: UnsafeCell<T>,
}
impl<T> CMutex<T> {
#[inline]
pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
pin_init!(CMutex {
wait_list <- ListHead::new(),
spin_lock: SpinLock::new(),
locked: Cell::new(false),
data <- unsafe {
pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
val.__pinned_init(slot.cast::<T>())
})
},
})
}
#[inline]
pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
let mut sguard = self.spin_lock.acquire();
if self.locked.get() {
stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
// println!("wait list length: {}", self.wait_list.size());
while self.locked.get() {
drop(sguard);
park();
sguard = self.spin_lock.acquire();
}
// This does have an effect, as the ListHead inside wait_entry implements Drop!
#[expect(clippy::drop_non_drop)]
drop(wait_entry);
}
self.locked.set(true);
unsafe {
Pin::new_unchecked(CMutexGuard {
mtx: self,
_pin: PhantomPinned,
})
}
}
#[allow(dead_code)]
pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
// SAFETY: we have an exclusive reference and thus nobody has access to data.
unsafe { &mut *self.data.get() }
}
}
unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}
pub struct CMutexGuard<'a, T> {
mtx: &'a CMutex<T>,
_pin: PhantomPinned,
}
impl<T> Drop for CMutexGuard<'_, T> {
#[inline]
fn drop(&mut self) {
let sguard = self.mtx.spin_lock.acquire();
self.mtx.locked.set(false);
if let Some(list_field) = self.mtx.wait_list.next() {
let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
unsafe { (*wait_entry).thread.unpark() };
}
drop(sguard);
}
}
impl<T> Deref for CMutexGuard<'_, T> {
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.mtx.data.get() }
}
}
impl<T> DerefMut for CMutexGuard<'_, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.mtx.data.get() }
}
}
#[pin_data]
#[repr(C)]
struct WaitEntry {
#[pin]
wait_list: ListHead,
thread: Thread,
}
impl WaitEntry {
#[inline]
fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
pin_init!(Self {
thread: thread::current(),
wait_list <- ListHead::insert_prev(list),
})
}
}
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
let mut handles = vec![];
let thread_count = 20;
let workload = if cfg!(miri) { 100 } else { 1_000 };
for i in 0..thread_count {
let mtx = mtx.clone();
handles.push(
Builder::new()
.name(format!("worker #{i}"))
.spawn(move || {
for _ in 0..workload {
*mtx.lock() += 1;
}
println!("{i} halfway");
sleep(Duration::from_millis((i as u64) * 10));
for _ in 0..workload {
*mtx.lock() += 1;
}
println!("{i} finished");
})
.expect("should not fail"),
);
}
for h in handles {
h.join().expect("thread panicked");
}
println!("{:?}", &*mtx.lock());
assert_eq!(*mtx.lock(), workload * thread_count * 2);
}