Vincent's Blog

Multiprocessing in Python

Gepubliceerd op

Categorie: misc

Tags: python

Python bevat een aantal functies om code meer parallel uit te voeren: Threading, Multiprocessing, AsyncIO. Multiprocessing komt met enkele handige API's waarmee code zeer eenvoudig over meerdere cores (of zelfs meerdere computers) kan worden uitgevoerd. In deze post geef ik een korte demo hoe dit kan worden toegepast op python code.

Multiprocessing: Pool based multiprocessing

Verreweg de meest eenvoudige manier om een hoop berekeningen parallel uit te voeren is met multiprocessing.pool. Hiermee hoeven we enkel een array te vullen met verschillende inputs voor een functie die vervolgens parallel zal worden uitgevoerd over alle beschikbare cores.

Ik gebruik hier express een trage en zware manier om priemgetallen te checken, om een beter resultaat te geven.

#!/usr/bin/env python3
import multiprocessing
import time
import math

def is_prime_heavy(n):
    """Heavy way to check on purpose, to be a better demo item"""
    if n < 2: return False
    for i in range(2, int(math.sqrt(n)) + 1):
        if n % i == 0:
            return False
    return True

def pool_demo():
    range_start = 100_000_000
    amount = 1_000_000
    numbers = range(range_start, range_start + amount)

    # 1. Sequential (1 core)
    print(f"--- Start test with 1 core (Sequential) ---")
    start_time = time.time()
    result_single = [is_prime_heavy(n) for n in numbers]
    time_single = time.time() - start_time
    print(f"Finished in: {time_single:.2f}s\n")

    # 2. Multiprocessing (All cores)
    num_cores = multiprocessing.cpu_count()
    print(f"--- Start test with {num_cores} cores (Pool) ---")

    start_time = time.time()
    with multiprocessing.Pool(processes=num_cores) as pool:
        result_multi = pool.map(is_prime_heavy, numbers)

    time_multi = time.time() - start_time
    print(f"Finished in: {time_multi:.2f}s\n")

    # Conclusie
    speedup = time_single / time_multi
    print(f"Pool based was {speedup:.1f}x faster compared to sequential.")

if __name__ == "__main__":
    pool_demo()

Als we dit uitvoeren zien we het volgende resultaat op een rpi5:

Zoals we zien is het met Pool +- 4x zo snel, te verwachten aangezien er 4 cores beschikbaar zijn.

Multiprocessing: Over meerdere hosts spreiden

Nu kunnen we nog een stap verder gaan: Spreiden over meerdere hosts. Dit doen we door een server - client model te maken. Een server start taken en zet ze op een queue, de clients (workers) verbinden met de server en halen taken op van de centrale queue

De volgende code start de server:

#!/usr/bin/env python3
import queue
import time
from multiprocessing.managers import BaseManager
import threading

class QueueManager(BaseManager): pass

def run_server(port=5001, authkey=b'prime', chunk_size=50000):
    task_queue = queue.Queue()
    result_queue = queue.Queue()

    QueueManager.register('get_task_queue', callable=lambda: task_queue)
    QueueManager.register('get_result_queue', callable=lambda: result_queue)

    manager = QueueManager(address=('', port), authkey=authkey)
    server = manager.get_server()

    # Fill queue with CHUNKS instead of single numbers
    range_start = 500_000_000
    total_amount = 10_000_000

    for i in range(range_start, range_start + total_amount, chunk_size):
        # Each task is a tuple: (start_num, end_num)
        task_queue.put((i, min(i + chunk_size, range_start + total_amount)))

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    num_chunks = task_queue.qsize()
    print(f"--- Server runs on port {port} ---")
    print(f"Total numbers: {total_amount} | Chunks in Queue: {num_chunks}")

    start_time = time.time()

    while not task_queue.empty():
        remaining = task_queue.qsize()
        progress = ((num_chunks - remaining) / num_chunks) * 100
        print(f"Progress: {progress:.1f}% ({remaining} chunks remaining)", end="\r")
        time.sleep(1)

    print("\n\nQueue empty! Waiting for final results...")
    time.sleep(2) 

    total_time = time.time() - start_time
    # Since workers put lists of primes back, we count the sum of lengths if needed, 
    # but here we just count how many result chunks came back.
    total_found = 0
    while not result_queue.empty():
        total_found += result_queue.get()

    print(f"Total time: {total_time:.2f}s")
    print(f"Efficiency: {total_amount / total_time:.2f} numbers/sec")
    print(f"Total Primes Found: {total_found}")

if __name__ == "__main__":
    run_server()

Nu kunnen we de worker maken:

#!/usr/bin/env python3
import argparse
import math
import multiprocessing
from multiprocessing.managers import BaseManager

class QueueManager(BaseManager): pass

def is_prime_heavy(n):
    if n < 2: return False
    for i in range(2, int(math.sqrt(n)) + 1):
        if n % i == 0:
            return False
    return True

def process_tasks(host, port, authkey):
    QueueManager.register('get_task_queue')
    QueueManager.register('get_result_queue')

    manager = QueueManager(address=(host, port), authkey=authkey.encode('utf-8'))

    try:
        manager.connect()
    except Exception as e:
        print(f"Connection error: {e}")
        return

    tasks = manager.get_task_queue()
    results = manager.get_result_queue()

    while True:
        try:
            # Get a chunk (start, end)
            start, end = tasks.get_nowait()

            # Process chunk locally (no network traffic during this loop)
            chunk_primes_count = 0
            for n in range(start, end):
                if is_prime_heavy(n):
                    chunk_primes_count += 1

            # Send back the result of the whole chunk
            results.put(chunk_primes_count)

        except Exception: # Empty queue
            break

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--key", default="prime")
    args = parser.parse_args()

    aantal_cores = multiprocessing.cpu_count()
    procs = []
    for i in range(aantal_cores):
        p = multiprocessing.Process(target=process_tasks, args=(args.host, args.port, args.key))
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
    print("Worker finished processing.")

Als we de worker ook op een andere raspberrypi installeren (in dit geval een rpi4) kunnen we het als volgt uitvoeren op onze rpi5 + rpi4:

# vanaf rpi5
python3 server.py & \
sleep 2; \
python3 worker.py --host 127.0.0.1 --port 5001 & \
ssh vincent@192.168.1.19 "python3 ~/worker.py --host $(hostname -I | awk '{print $1}') --port 5001"

Als we dit uitvoeren duurt het nog maar 180 seconden! Dit is weer een hele versnelling!