"""Run callables across a process pool while showing a terminal spinner.

Tasks are wrapped in :class:`WorkerProcess`, executed in pool workers via
:func:`dispatch_worker`, and orchestrated by :class:`MultiprocessWorker`.
"""

from util.generic.log import Log
from dataclasses import dataclass
from enum import Enum
import sys
import time
from datetime import datetime
from multiprocessing import Pool, Manager
import threading
from logging import Logger
from typing import Any, Callable, Optional
import re


@dataclass
class Spinner:
    """Frames of a terminal spinner animation and the pause between frames."""

    char_list: list[str]  # frames shown in order, cycling back to the first
    delay: float  # seconds to sleep between frames


class Spinners(Enum):
    """Predefined spinner animations, selectable by name."""

    BASIC = Spinner(["|", "/", "-", "\\"], 0.1)
    SPIN_TRI_BLOCK = Spinner(["▙", "▛", "▜", "▟"], 0.1)
    SPIN_RIGHT_ANGLE = Spinner(["🮤", "🮧", "🮥", "🮦"], 0.1)
    SPIN_OPEN_CUBE = Spinner(["🮪", "🮫", "🮭", "🮬"], 0.1)
    SPIN_DOTS = Spinner(["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"], 0.1)
    GUY_DANCING = Spinner(["🯇", "🯈"], 0.3)
    GUY_JUMPING = Spinner(["🯅", "🯆"], 0.3)
    SHIFTING_PATTERN = Spinner(["🮕", "🮖"], 0.1)
    SHIFTING_GRADIENT = Spinner(
        ["░▒▓█▓▒░", "▒▓█▓▒░░", "▓█▓▒░▒▓", "█▓▒░▒▓█", "▓▒░▒▓█▓", "░░▒▓█▓▒"], 0.1
    )


@dataclass
class WorkerProcess:
    """A unit of work: ``function`` is called with ``instructions`` as its only argument.

    The instance is sent to a pool worker process, so both fields must be
    picklable (e.g. ``function`` should be a module-level function, not a
    lambda).
    """

    function: Callable[..., Any]
    instructions: tuple


def dispatch_worker(worker_process: WorkerProcess) -> Any:
    """Execute one task inside a pool worker and return its result.

    The ``instructions`` tuple is passed as a single positional argument
    (it is not unpacked). Exceptions propagate to the pool, which hands them
    to ``apply_async``'s ``error_callback`` in the parent process.
    """
    return worker_process.function(worker_process.instructions)


@dataclass
class WorkerReturn:
    """Outcome of a command run by a worker, plus a helper to log it."""

    status: bool  # True when the command succeeded
    command: str  # the command that was run; success_regex is matched against it
    stdout: str = ""
    stderr: str = ""
    error: Optional[str] = None

    def output_result(
        self,
        logger: Logger,
        success_regex: Optional[str] = None,
        success_message: Optional[str] = None,
    ) -> None:
        """Log this result: a success/failure line, then any error, stdout and stderr.

        Args:
            logger: Destination logger.
            success_regex: On success, searched in ``self.command``. Its first
                capture group (or the whole match if the pattern has no
                groups) is substituted into ``success_message``; if it does
                not match, ``"unknown_target"`` is used instead.
            success_message: Format string with an ``{output}`` placeholder.
                Used only when both it and ``success_regex`` are provided.
        """
        if self.status:
            result_level = Log.Level.INFO
            if success_regex and success_message:
                success_match = re.search(success_regex, self.command)
                if success_match is None:
                    success_string = "unknown_target"
                elif success_match.re.groups:
                    success_string = success_match.group(1)
                else:
                    # group(1) would raise IndexError for a pattern with no groups.
                    success_string = success_match.group(0)
                logger.log(result_level, success_message.format(output=success_string))
            else:
                logger.log(result_level, f"SUCCESS: {self.command}")
        else:
            result_level = Log.Level.ERROR
            logger.log(result_level, f"FAILURE: {self.command}")

        # Details are logged at the same level as the summary line above.
        # NOTE(review): the original indentation was lost; these lines are
        # assumed to run for successes as well as failures — confirm.
        if self.error:
            logger.log(result_level, f"Error: {self.error}")
        if self.stdout:
            logger.log(result_level, f"STDOUT:\n{self.stdout}")
        if self.stderr:
            logger.log(result_level, f"STDERR:\n{self.stderr}")


class MultiprocessWorker:
    """Collect ``WorkerProcess`` tasks and run them on a ``multiprocessing.Pool``.

    While tasks run, a daemon thread animates a spinner and a progress
    counter on stdout.
    """

    def __init__(
        self,
        logger: Optional[Logger] = None,
        max_processes=4,
        spinner_set: Spinners = Spinners.BASIC,
    ):
        """Create a worker.

        Args:
            logger: Logger to use; when None, one is created via ``Log`` at DEBUG level.
            max_processes: Number of processes in the pool.
            spinner_set: Animation shown on stdout while tasks run.
        """
        if logger is None:
            logger = Log(self.__class__.__name__, Log.Level.DEBUG).create_logger()
        self.logger = logger
        self.tasks: list[WorkerProcess] = []
        self.max_processes = max_processes
        self.spinner: Spinner = spinner_set.value
        # Exceptions raised by tasks during the most recent run().
        self.errors: list[BaseException] = []
        self.manager = Manager()
        # NOTE(review): active_processes and _process_lock are not used by
        # this class; kept in case external code relies on them.
        self.active_processes = self.manager.Value('i', 0)
        self.stop_display = self.manager.Event()
        self._process_lock = threading.Lock()

    def add_tasks(self, instructions: list[WorkerProcess]):
        """Queue several tasks for the next run()."""
        self.logger.log(Log.Level.DEBUG, f"Adding {len(instructions)} tasks.")
        self.tasks.extend(instructions)

    def add_task(self, instruction: WorkerProcess):
        """Queue a single task for the next run()."""
        self.logger.log(Log.Level.DEBUG, f"Adding task with callable {instruction.function.__name__}")
        self.tasks.append(instruction)

    def _spinner_display(self, total_tasks, results, failures=None):
        """Animate progress on stdout until all tasks finish or ``stop_display`` is set.

        Args:
            total_tasks: Number of submitted tasks.
            results: List the pool appends successful results to.
            failures: List the pool appends task exceptions to (optional).
                Both lists count toward completion.
        """
        if failures is None:
            failures = []
        frames = self.spinner.char_list
        idx = 0
        # Checking stop_display lets run() end the animation even if some
        # tasks never report back (e.g. the pool itself failed).
        while not self.stop_display.is_set():
            completed = len(results) + len(failures)
            if completed >= total_tasks:
                break
            sys.stdout.write(f"\r{frames[idx]} | Completed: {completed}/{total_tasks} | Time: {datetime.now().strftime('%H:%M:%S')} ")
            sys.stdout.flush()
            idx = (idx + 1) % len(frames)
            time.sleep(self.spinner.delay)
        sys.stdout.write("\r")
        sys.stdout.flush()

    def run(self):
        """Run every queued task on a process pool and return the successful results.

        Results are collected in completion order, not submission order.
        Exceptions raised by tasks are logged and stored in ``self.errors``;
        they are not part of the returned list.

        Returns:
            list: Return values of the tasks that completed without raising.
        """
        self.logger.log(Log.Level.DEBUG, f"Starting with {len(self.tasks)} tasks and max {self.max_processes} processes.")
        results = []
        failures: list[BaseException] = []
        self.errors = failures
        # A previous run() leaves the event set; clear it so the spinner runs.
        self.stop_display.clear()

        def on_error(exc: BaseException) -> None:
            # Called in the pool's result-handler thread. Without this, a
            # failing task is never counted and the spinner never finishes.
            failures.append(exc)
            self.logger.log(Log.Level.ERROR, f"Task failed: {exc!r}", exc_info=exc)

        spinner_thread = threading.Thread(
            target=self._spinner_display,
            args=(len(self.tasks), results, failures),
            daemon=True,
        )
        spinner_thread.start()
        try:
            with Pool(self.max_processes) as pool:
                for wp in self.tasks:
                    pool.apply_async(
                        dispatch_worker,
                        args=(wp,),
                        callback=results.append,
                        error_callback=on_error,
                    )
                pool.close()
                pool.join()
        except Exception:
            self.logger.exception("Unexpected error during multiprocessing run")
        finally:
            self.stop_display.set()
            spinner_thread.join()
        self.logger.log(Log.Level.DEBUG, "All multiprocessing tasks completed.")
        return results