130 lines
4.8 KiB
Python
130 lines
4.8 KiB
Python
import re
import sys
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from logging import Logger
from multiprocessing import Manager, Pool
from typing import Any, Callable, Optional

from util.generic.log import Log
|
|
|
|
@dataclass
class Spinner:
    """Frames and timing for a terminal progress spinner.

    Fields are validated on construction, so a misconfigured spinner fails
    immediately instead of inside the background display thread.
    """

    # Frames drawn one per tick; the display wraps back to the first frame.
    char_list: list[str]
    # Seconds to pause between frames.
    delay: float

    def __post_init__(self) -> None:
        # The display indexes char_list modulo its length, so it must be non-empty.
        if not self.char_list:
            raise ValueError("Spinner.char_list must contain at least one frame")
        # A negative pause is rejected by time.sleep(); catch it up front.
        if self.delay < 0:
            raise ValueError("Spinner.delay must be non-negative")
|
|
|
|
class Spinners(Enum):
    """Catalogue of ready-made spinner styles.

    Each member's value is a Spinner (frames plus per-frame delay). Several
    styles draw glyphs from the Unicode "Symbols for Legacy Computing" block,
    which may not render in every terminal font; BASIC uses plain ASCII.
    """

    # --- plain ASCII ---
    BASIC = Spinner(char_list=["|", "/", "-", "\\"], delay=0.1)

    # --- rotating single glyphs ---
    SPIN_TRI_BLOCK = Spinner(char_list=["▙", "▛", "▜", "▟"], delay=0.1)
    SPIN_RIGHT_ANGLE = Spinner(char_list=["🮤", "🮧", "🮥", "🮦"], delay=0.1)
    SPIN_OPEN_CUBE = Spinner(char_list=["🮪", "🮫", "🮭", "🮬"], delay=0.1)
    SPIN_DOTS = Spinner(
        char_list=["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
        delay=0.1,
    )

    # --- two-frame figures, slower ---
    GUY_DANCING = Spinner(char_list=["🯇", "🯈"], delay=0.3)
    GUY_JUMPING = Spinner(char_list=["🯅", "🯆"], delay=0.3)

    # --- patterns ---
    SHIFTING_PATTERN = Spinner(char_list=["🮕", "🮖"], delay=0.1)
    SHIFTING_GRADIENT = Spinner(
        char_list=[
            "░▒▓█▓▒░",
            "▒▓█▓▒░░",
            "▓█▓▒░▒▓",
            "█▓▒░▒▓█",
            "▓▒░▒▓█▓",
            "░░▒▓█▓▒",
        ],
        delay=0.1,
    )
|
|
|
|
@dataclass
class WorkerProcess:
    """One unit of work to run on the process pool.

    The worker calls ``function(instructions)``: the tuple is passed as a
    single positional argument, not unpacked. Instances are sent to pool
    worker processes as task arguments, so both fields must be picklable
    (typically ``function`` is a module-level callable).
    """

    # Callable invoked in a worker process with ``instructions`` as its only argument.
    function: Callable[..., Any]
    # Argument bundle handed to ``function`` unchanged.
    instructions: tuple
|
|
|
|
def dispatch_worker(worker_process: WorkerProcess) -> Any:
    """Run one task inside a pool worker process.

    Args:
        worker_process: The task; its ``function`` is called with its
            ``instructions`` tuple as the single argument.

    Returns:
        Whatever ``function`` returns. The pool sends it back to the parent
        process, so it must be picklable.

    Raises:
        Any exception raised by ``function`` propagates unchanged; the
        multiprocessing pool reports it to the parent for that task.
    """
    # No try/except here: the previous catch-and-bare-raise was a no-op.
    return worker_process.function(worker_process.instructions)
|
|
|
|
@dataclass
class WorkerReturn:
    """Outcome of one executed command, plus a helper to log it.

    NOTE(review): presumably produced by task functions run on the pool;
    confirm against callers.
    """

    # True when the command succeeded; selects INFO vs ERROR logging.
    status: bool
    # The command that ran; echoed in log lines and searched by success_regex.
    command: str
    # Captured output streams; logged only when non-empty.
    stdout: str = ""
    stderr: str = ""
    # Optional error description, logged only on failure.
    error: Optional[str] = None

    def output_result(self, logger: Logger, success_regex: Optional[str] = None, success_message: Optional[str] = None):
        """Log this result through ``logger``.

        On success (Log.Level.INFO): if both ``success_regex`` and
        ``success_message`` are given, logs ``success_message`` formatted with
        ``output=`` set to the text matched in ``command``. That is capture
        group 1 if the pattern has a group, otherwise the whole match, and
        ``"unknown_target"`` when nothing matches. Otherwise it logs
        ``"SUCCESS: <command>"``.

        On failure (Log.Level.ERROR): logs ``"FAILURE: <command>"`` and, if set,
        ``"Error: <error>"``.

        In both cases non-empty stdout and stderr follow at the same level.

        Args:
            logger: Logger to write to.
            success_regex: Pattern searched in ``command`` to extract a label.
            success_message: ``str.format`` template using ``{output}``.
        """
        if self.status:
            level = Log.Level.INFO
            if success_regex and success_message:
                match = re.search(success_regex, self.command)
                if match is None:
                    target = "unknown_target"
                else:
                    # Patterns without a capture group would make group(1)
                    # raise IndexError; use the whole match instead.
                    target = match.group(1) if match.re.groups else match.group(0)
                logger.log(level, success_message.format(output=target))
            else:
                logger.log(level, f"SUCCESS: {self.command}")
        else:
            level = Log.Level.ERROR
            logger.log(level, f"FAILURE: {self.command}")
            if self.error:
                logger.log(level, f"Error: {self.error}")

        if self.stdout:
            logger.log(level, f"STDOUT:\n{self.stdout}")
        if self.stderr:
            logger.log(level, f"STDERR:\n{self.stderr}")
|
|
|
|
class MultiprocessWorker:
    """Run queued WorkerProcess tasks on a multiprocessing pool with a console spinner.

    Queue work with add_task()/add_tasks(), then call run(). While the pool
    works, a daemon thread animates the configured spinner and a completion
    counter on stdout.
    """

    def __init__(self, logger: Optional[Logger] = None, max_processes=4, spinner_set: Spinners = Spinners.BASIC):
        """Set up an empty task queue and the state used by run().

        Args:
            logger: Destination for progress and error messages. When None, a
                logger named after this class is created through the project
                ``Log`` helper at DEBUG level.
            max_processes: Number of worker processes passed to ``Pool``.
            spinner_set: Which Spinners member to animate during run().
        """
        if logger is None:
            logger = Log(self.__class__.__name__, Log.Level.DEBUG).create_logger()
        self.logger = logger

        self.tasks: list[WorkerProcess] = []
        self.max_processes = max_processes
        self.spinner: Spinner = spinner_set.value
        # Manager() launches a separate server process; the objects below are
        # proxies to state held in it.
        self.manager = Manager()
        # NOTE(review): not read or written anywhere in this class.
        self.active_processes = self.manager.Value('i', 0)
        # Set by run() once the pool has finished so the spinner thread exits.
        self.stop_display = self.manager.Event()
        # NOTE(review): not used anywhere in this class.
        self._process_lock = threading.Lock()

    def add_tasks(self, instructions: list[WorkerProcess]):
        """Append several tasks to the queue.

        Any iterable of WorkerProcess is accepted; it is materialized first
        so its length can be logged.
        """
        new_tasks = list(instructions)
        self.logger.log(Log.Level.DEBUG, f"Adding {len(new_tasks)} tasks.")
        self.tasks.extend(new_tasks)

    def add_task(self, instruction: WorkerProcess):
        """Append a single task to the queue."""
        # functools.partial objects and callable instances have no __name__;
        # fall back to repr() rather than raising AttributeError.
        label = getattr(instruction.function, "__name__", repr(instruction.function))
        self.logger.log(Log.Level.DEBUG, f"Adding task with callable {label}")
        self.tasks.append(instruction)

    def _spinner_display(self, total_tasks, results):
        """Animate the spinner until all tasks succeed or stop_display is set.

        Args:
            total_tasks: Number of tasks submitted to the pool.
            results: List the pool's callback appends to as tasks succeed;
                its length is shown as the completed count.
        """
        frames = self.spinner.char_list
        idx = 0
        last_width = 0
        # Checking stop_display (not only the count) lets the loop end when
        # some tasks fail and therefore never append to ``results``.
        while len(results) < total_tasks and not self.stop_display.is_set():
            line = (
                f"\r{frames[idx]} | Completed: {len(results)}/{total_tasks} "
                f"| Time: {datetime.now().strftime('%H:%M:%S')} "
            )
            sys.stdout.write(line)
            sys.stdout.flush()
            last_width = len(line)  # includes '\r', leaving one spare column
            idx = (idx + 1) % len(frames)
            time.sleep(self.spinner.delay)

        # Blank out the last frame so later output is not printed on top of it.
        sys.stdout.write("\r" + " " * last_width + "\r")
        sys.stdout.flush()

    def run(self):
        """Execute all queued tasks on a process pool.

        Returns:
            List of return values from tasks that finished without raising,
            in completion order (not submission order). Failed tasks are
            logged and omitted. Errors outside the tasks themselves are
            logged and the results gathered so far are returned.
        """
        tasks = list(self.tasks)
        total = len(tasks)
        self.logger.log(Log.Level.DEBUG, f"Starting with {total} tasks and max {self.max_processes} processes.")

        results = []
        failures = []

        def _on_error(exc):
            # Called in the pool's result-handler thread for each task that raised.
            failures.append(exc)
            self.logger.log(Log.Level.ERROR, f"Task failed: {exc!r}", exc_info=exc)

        # Re-arm the stop flag so a second run() call shows the spinner again.
        self.stop_display.clear()
        spinner_thread = threading.Thread(target=self._spinner_display, args=(total, results), daemon=True)
        spinner_thread.start()

        try:
            with Pool(self.max_processes) as pool:
                for wp in tasks:
                    pool.apply_async(
                        dispatch_worker,
                        args=(wp,),
                        callback=results.append,
                        error_callback=_on_error,
                    )
                pool.close()
                pool.join()
        except Exception:
            self.logger.exception("Unexpected error during multiprocessing run")
        finally:
            # Always stop the spinner; without this the join below could block forever.
            self.stop_display.set()
            spinner_thread.join()

        if failures:
            self.logger.log(Log.Level.ERROR, f"{len(failures)} of {total} tasks failed.")
        self.logger.log(Log.Level.DEBUG, "All multiprocessing tasks completed.")
        return results
|