请教一个 Python 中线程共享数据的问题

smdbh · 2024-10-8 10:56:44 · 23 次点击
我想在 main 和它建立的线程间共享数据,在线程中执行逻辑,更新数据,主线程中读取判断。
1. 由于数据较多,使用 dataclass 当 struct 用
2. 线程中写,main 中只读,所有没有加锁
实际使用发现,这个数据共享不是完全引用,变量地址(使用 id 查看两边地址)会有改变,导致 main 和 thread 中的变量不是一个东西了,监测失败。
3. tricky 的是,第一次创建的线程没有问题。跑完一次,第二次再来一次就大概率出问题,后续再尝试就一直会出问题了,偶尔会成功。
请问如果要实现多线程共享数据的读写,有什么最佳实现和模板吗
举报· 23 次点击
登录 注册 站外分享
2 条回复  
djangovcps 小成 2024-10-8 11:40:16
threading.lock ?
qianchengv 初学 2024-10-8 12:01:58
```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
import unittest
from dataclasses import dataclass, field
from threading import Lock
import multiprocessing

@dataclass
class SharedData:
    value: int = 0
    # Using a Lock to ensure thread-safety when accessing shared data
    lock: Lock = field(default_factory=Lock, init=False, repr=False)

    def increment(self):
        with self.lock:
            self.value += 1

    def get_value(self):
        with self.lock:
            return self.value

def worker(data: SharedData, num_iterations: int):
    local_sum = 0
    for _ in range(num_iterations):
        local_sum += 1
    # Use a lock to safely update the shared data
    with data.lock:
        data.value += local_sum

class TestSharedDataThreadSafety(unittest.TestCase):
    def test_concurrent_increments(self):
        shared_data = SharedData()
        # Use 2x CPU count for threads to test both CPU-bound and I/O-bound scenarios
        num_threads = multiprocessing.cpu_count() * 2
        num_iterations = 1000000 // num_threads

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = [executor.submit(worker, shared_data, num_iterations) for _ in range(num_threads)]
            for future in futures:
                future.result()

        expected_value = num_threads * num_iterations
        self.assertEqual(shared_data.get_value(), expected_value,
                         f"Expected {expected_value}, but got {shared_data.get_value()}")

    def test_race_condition(self):
        shared_data = SharedData()
        race_detected = threading.Event()

        def racer():
            with shared_data.lock:
                initial_value = shared_data.value
                time.sleep(0.001)  # Simulate some work
                # Check if the value has changed, which would indicate a race condition
                if initial_value == shared_data.value:
                    shared_data.value += 1
                else:
                    race_detected.set()

        threads = [threading.Thread(target=racer) for _ in range(100)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        self.assertFalse(race_detected.is_set(), "Race condition detected")

    def test_stress_test(self):
        shared_data = SharedData()
        stop_flag = threading.Event()

        def stress_worker():
            local_sum = 0
            while not stop_flag.is_set():
                local_sum += 1
            # Use a lock to safely update the shared data after intensive local computation
            with shared_data.lock:
                shared_data.value += local_sum

        # Use CPU count for threads to maximize resource utilization
        threads = [threading.Thread(target=stress_worker) for _ in range(multiprocessing.cpu_count())]
        for t in threads:
            t.start()

        time.sleep(5)  # Run for 5 seconds to simulate prolonged stress
        stop_flag.set()

        for t in threads:
            t.join()

        print(f"Stress test final value: {shared_data.get_value()}")

if __name__ == '__main__':
    unittest.main()
```
返回顶部