diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e432a77..7783151 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,8 +22,15 @@ repos:
    rev: v5.8.0
    hooks:
      - id: isort
+       args: [--profile=black, --line-length=120]
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.812
    hooks:
      - id: mypy
+
+ - repo: https://github.com/pre-commit/pygrep-hooks
+   rev: v1.8.0
+   hooks:
+     - id: python-check-blanket-noqa
+     - id: python-check-mock-methods
diff --git a/README.md b/README.md
index b538514..4335c6c 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,12 @@
 # blocksync
 
-[![Build](https://travis-ci.com/ehdgua01/blocksync.svg?branch=master)](https://travis-ci.com/github/ehdgua01/blocksync)
-[![Coverage](https://codecov.io/gh/ehdgua01/blocksync/branch/master/graph/badge.svg)](https://app.codecov.io/gh/ehdgua01/blocksync)
-[![PyPi](https://badge.fury.io/py/blocksync.svg)](https://pypi.org/project/blocksync/)
-[![PyVersions](https://img.shields.io/pypi/pyversions/blocksync)](https://pypi.org/project/blocksync/)
-
 Blocksync Python package allows [blocksync script](https://github.com/theraser/blocksync) to be used as Python packages,
 and supports more convenient and various functions than blocksync script.
 
+[![Build](https://img.shields.io/travis/ehdgua01/blocksync/master.svg?style=for-the-badge&logo=travis)](https://travis-ci.com/github/ehdgua01/blocksync)
+[![PyPi](https://img.shields.io/pypi/v/blocksync?logo=pypi&style=for-the-badge)](https://pypi.org/project/blocksync/)
+[![PyVersions](https://img.shields.io/pypi/pyversions/blocksync?logo=python&style=for-the-badge)](https://pypi.org/project/blocksync/)
+
 # Prerequisites
 
 - Python 3.8 or later
 
@@ -15,7 +14,7 @@ and supports more convenient and various functions than blocksync script.
 # Features
 
 - Synchronize the destination (remote or local) files using an incremental algorithm.
-- Supports all synchronization directions. (local-local, local-remote, remote-local, remote-remote)
+- Supports all synchronization directions. (local-local, local-remote, remote-local)
 - Support for callbacks that can run before(run once or per workers), after(run once or per workers), and during synchronization of files
 - Support for synchronization suspend/resume, cancel.
 - Most methods support method chaining.
@@ -31,23 +30,22 @@ pip install blocksync
 # Quick start
 
-When using SFTP files, you can check the SSH connection options in [paramiko docs](http://docs.paramiko.org/en/stable/api/client.html#paramiko.client.SSHClient).
+When syncing from/to a remote host, you can check the SSH connection options in [paramiko docs](http://docs.paramiko.org/en/stable/api/client.html#paramiko.client.SSHClient).
 
 ```python
-from blocksync import LocalFile, SFTPFile, Syncer
-
-
-syncer = Syncer(
-    src=SFTPFile(
-        path="src.file",
-        hostname="hostname",
-        username="username",
-        password="password",
-        key_filename="key_filepath",
-    ),
-    dest=LocalFile(path="dest.file"),
-)
-syncer.start_sync(workers=2, create=True, wait=True)
+from blocksync import local_to_local
+
+
+manager, status = local_to_local("src.txt", "dest.txt", workers=4)
+manager.wait_sync()
+print(status)
+
+# Output
+[Worker 1]: Start sync(src.txt -> dest.txt) 1 blocks
+[Worker 2]: Start sync(src.txt -> dest.txt) 1 blocks
+[Worker 3]: Start sync(src.txt -> dest.txt) 1 blocks
+[Worker 4]: Start sync(src.txt -> dest.txt) 1 blocks
+{'workers': 4, 'chunk_size': 250, 'block_size': 250, 'src_size': 1000, 'dest_size': 1000, 'blocks': {'same': 4, 'diff': 0, 'done': 4}}
 ```
diff --git a/blocksync/__init__.py b/blocksync/__init__.py
index 7f67fe8..512274a 100644
--- a/blocksync/__init__.py
+++ b/blocksync/__init__.py
@@ -1,4 +1,5 @@
-from blocksync.files import File, LocalFile, SFTPFile
-from blocksync.syncer import Syncer
+from blocksync._status import Status
+from blocksync._sync_manager import SyncManager
+from blocksync.sync import local_to_local, local_to_remote, remote_to_local
 
-__all__ = ["File", "LocalFile", "SFTPFile", "Syncer"]
+__all__ = ["local_to_local", "local_to_remote", "remote_to_local", "Status", "SyncManager"]
diff --git a/blocksync/_consts.py b/blocksync/_consts.py
new file mode 100644
index 0000000..ac6c23b
--- /dev/null
+++ b/blocksync/_consts.py
@@ -0,0 +1,37 @@
+import re
+from pathlib import Path
+
+__all__ = ["BASE_DIR", "ByteSizes", "SAME", "SKIP", "DIFF"]
+
+BASE_DIR = Path(__file__).parent
+SAME: str = "0"
+SKIP: str = "1"
+DIFF: str = "2"
+
+
+class ByteSizes:
+    BLOCK_SIZE_PATTERN = re.compile("([0-9]+)(B|KB|MB|GB|KiB|K|MiB|M|GiB|G)")
+
+    B: int = 1
+    KB: int = 1000
+    MB: int = 1000 ** 2
+    GB: int = 1000 ** 3
+    KiB: int = 1 << 10
+    KIB = K = KiB  # upper-case alias so getattr(..., unit.upper()) resolves "KiB"
+    MiB: int = 1 << 20
+    MIB = M = MiB  # likewise for "MiB"
+    GiB: int = 1 << 30
+    GIB = G = GiB  # likewise for "GiB"
+
+    @classmethod
+    def parse_readable_byte_size(cls, size: str) -> int:
+        """
+        Examples
+            1MB -> 1000000
+            1M, 1MiB -> 1048576
+        """
+        if not size.isdigit():
+            if matched := cls.BLOCK_SIZE_PATTERN.match(size):
+                size, unit = matched.group(1), matched.group(2).strip()
+                return int(size) * getattr(ByteSizes, unit.upper())
+        return int(size)
diff --git a/blocksync/_hooks.py b/blocksync/_hooks.py
new file mode 100644
index 0000000..b76ed8c
--- /dev/null
+++ b/blocksync/_hooks.py
@@ -0,0 +1,35 @@
+from typing import Any, Callable, Optional
+
+from blocksync._status import Status
+
+__all__ = ["Hooks"]
+
+
+class Hooks:
+    def __init__(
+        self,
+        on_before: Optional[Callable[..., Any]],
+        on_after: Optional[Callable[[Status], Any]],
+        monitor: Optional[Callable[[Status], Any]],
+        on_error: Optional[Callable[[Exception, Status], Any]],
+    ):
+        self.before: Optional[Callable[..., Any]] = on_before
+        self.after: Optional[Callable[[Status], Any]] = on_after
+        self.monitor: Optional[Callable[[Status], Any]] = monitor
+        self.on_error: Optional[Callable[[Exception, Status], Any]] = on_error
+
+    def _run(self, hook: Optional[Callable], *args, **kwargs):
+        if hook:
+            hook(*args, **kwargs)
+
+    def run_before(self):
+        self._run(self.before)
+
+    def run_after(self, status: Status):
+        self._run(self.after, status)
+
+    def run_monitor(self, status: Status):
+        self._run(self.monitor, status)
+
+    def run_on_error(self, exc: Exception, status: Status):
+        self._run(self.on_error, exc, status)
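The new `Hooks` class replaces the old `root_before`/`root_after` pairs with a single per-worker set of callbacks. A minimal sketch of wiring them through the functional API introduced later in this diff (file names here are hypothetical):

```python
from blocksync import local_to_local

def monitor(status):
    print(f"{status.rate:.2f}% done")  # Status.rate runs from 0.00 to 100.00

def on_error(exc, status):
    print(f"failed after {status.blocks['done']} blocks: {exc}")

# With wait=True the manager slot of the returned tuple is None,
# so only the final status is kept here.
_, status = local_to_local(
    "src.img", "dest.img",  # hypothetical paths
    workers=2,
    monitor=monitor,
    on_error=on_error,
    wait=True,
)
```

diff --git a/blocksync/_read_server.py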
b/blocksync/_read_server.py new file mode 100644 index 0000000..ebd4308 --- /dev/null +++ b/blocksync/_read_server.py @@ -0,0 +1,29 @@ +import hashlib +import io +import sys +from typing import Callable + +DIFF = b"2" +COMPLEN = len(DIFF) +path: bytes = sys.stdin.buffer.readline().strip() +stdout = sys.stdout.buffer +stdin = sys.stdin.buffer + +fileobj = open(path, "rb") +fileobj.seek(io.SEEK_SET, io.SEEK_END) +print(fileobj.tell(), flush=True) + +block_size: int = int(stdin.readline()) +hash_: Callable = getattr(hashlib, stdin.readline().strip().decode()) +startpos: int = int(stdin.readline()) +maxblock: int = int(stdin.readline()) + +with fileobj: + fileobj.seek(startpos) + for _ in range(maxblock): + block = fileobj.read(block_size) + stdout.write(hash_(block).digest()) + stdout.flush() + if stdin.read(COMPLEN) == DIFF: + stdout.write(block) + stdout.flush() diff --git a/blocksync/_status.py b/blocksync/_status.py new file mode 100644 index 0000000..e78dfde --- /dev/null +++ b/blocksync/_status.py @@ -0,0 +1,41 @@ +import threading +from typing import Literal, TypedDict + + +class Blocks(TypedDict): + same: int + diff: int + done: int + + +class Status: + def __init__( + self, + workers: int, + block_size: int, + src_size: int, + dest_size: int = 0, + ): + self._lock = threading.Lock() + self.workers: int = workers + self.chunk_size: int = src_size // workers + self.block_size: int = block_size + self.src_size: int = src_size + self.dest_size: int = dest_size + self.blocks: Blocks = Blocks(same=0, diff=0, done=0) + + def __repr__(self): + return str({k: v for k, v in self.__dict__.items() if k != "_lock"}) + + def add_block(self, block_type: Literal["same", "diff"]): + with self._lock: + self.blocks[block_type] += 1 + self.blocks["done"] = self.blocks["same"] + self.blocks["diff"] + + @property + def rate(self) -> float: + return ( + min(100.00, (self.blocks["done"] / (self.src_size // self.block_size)) * 100) + if self.blocks["done"] > 1 + else 0.00 + ) diff --git a/blocksync/_sync_manager.py b/blocksync/_sync_manager.py new file mode 100644 index 0000000..92aa18a --- /dev/null +++ b/blocksync/_sync_manager.py @@ -0,0 +1,41 @@ +import threading +from typing import List + + +class SyncManager: + def __init__(self): + self.workers: List[threading.Thread] = [] + self._suspend: threading.Event = threading.Event() + self._suspend.set() + self._cancel: bool = False + + def cancel_sync(self): + self._cancel = True + + def wait_sync(self): + for worker in self.workers: + worker.join() + + def suspend(self): + self._suspend.clear() + + def resume(self): + self._suspend.set() + + def _wait_resuming(self): + self._suspend.wait() + + @property + def canceled(self) -> bool: + return self._cancel + + @property + def suspended(self) -> bool: + return not self._suspend.is_set() + + @property + def finished(self) -> bool: + for worker in self.workers: + if worker.is_alive(): + return False + return True diff --git a/blocksync/_write_server.py b/blocksync/_write_server.py new file mode 100644 index 0000000..4166c33 --- /dev/null +++ b/blocksync/_write_server.py @@ -0,0 +1,25 @@ +import io +import sys + +DIFF = b"2" +COMPLEN = len(DIFF) +stdin = sys.stdin.buffer + +path = stdin.readline().strip() + +size = int(stdin.readline()) +if size > 0: + with open(path, "a+") as fileobj: + fileobj.truncate(size) + +block_size = int(stdin.readline()) +startpos = int(stdin.readline()) +maxblock = int(stdin.readline()) + +with open(path, mode="rb+") as f: + f.seek(startpos) + for _ in range(maxblock): + if 
stdin.read(COMPLEN) == DIFF: + f.write(stdin.read(block_size)) + else: + f.seek(block_size, io.SEEK_CUR) diff --git a/blocksync/consts.py b/blocksync/consts.py deleted file mode 100644 index 9bb9609..0000000 --- a/blocksync/consts.py +++ /dev/null @@ -1,12 +0,0 @@ -__all__ = ["ByteSizes", "SSH_PORT"] - -SSH_PORT = 22 - - -class ByteSizes: - KB: int = 1000 - MB: int = 1000 ** 2 - GB: int = 1000 ** 3 - KiB: int = 1 << 10 - MiB: int = 1 << 20 - GiB: int = 1 << 30 diff --git a/blocksync/files/__init__.py b/blocksync/files/__init__.py deleted file mode 100644 index 63d02ae..0000000 --- a/blocksync/files/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from blocksync.files.interfaces import File -from blocksync.files.local_file import LocalFile -from blocksync.files.sftp_file import SFTPFile - -__all__ = ["File", "LocalFile", "SFTPFile"] diff --git a/blocksync/files/interfaces.py b/blocksync/files/interfaces.py deleted file mode 100644 index b35cf2b..0000000 --- a/blocksync/files/interfaces.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import abc -import os -import threading -from pathlib import Path -from typing import IO, Generator, Optional, Union - -import paramiko - -from blocksync.consts import ByteSizes - -__all__ = ["File"] - - -class LocalThreadVars(threading.local): - io: Optional[Union[IO, paramiko.SFTPFile]] = None - - -class File(abc.ABC): - def __init__(self, path: Union[Path, str]): - self._local: LocalThreadVars = LocalThreadVars() - self.path: Union[Path, str] = path - self.size: int = 0 - - def __repr__(self): - return f"<{self.__class__.__name__} path={self.path} opened={self.opened}>" - - def do_close(self, flush: bool = True) -> File: - if self.opened: - if flush: - self.io.flush() # type: ignore[union-attr] - self.io.close() # type: ignore[union-attr] - return self - - def do_create(self, size: int) -> File: - with self._open(mode="w") as f: - f.truncate(size) - return self - - def do_open(self) -> File: - fileobj = self._open(mode="rb+") - self._local.io = fileobj - self.size = self._get_size(fileobj) - return self - - def get_blocks(self, block_size: int = ByteSizes.MiB) -> Generator[bytes, None, None]: - while self.opened and (block := self.get_block(block_size)): - yield block - - def get_block(self, block_size: int = ByteSizes.MiB) -> Optional[bytes]: - return self.io.read(block_size) if self.opened and self.io else None - - @abc.abstractmethod - def _open(self, mode: str) -> Union[IO, paramiko.SFTPFile]: - raise NotImplementedError - - def _get_size(self, fileobj: Union[IO, paramiko.SFTPFile]) -> int: - fileobj.seek(os.SEEK_SET, os.SEEK_END) - size = fileobj.tell() - fileobj.seek(os.SEEK_SET) - return size - - @property - def io(self) -> Optional[Union[IO, paramiko.SFTPFile]]: - return self._local.io - - @property - def opened(self) -> bool: - if io := self.io: - return not io.closed - return False diff --git a/blocksync/files/local_file.py b/blocksync/files/local_file.py deleted file mode 100644 index 0d55c52..0000000 --- a/blocksync/files/local_file.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations - -import io -from typing import IO - -from blocksync.files.interfaces import File - -__all__ = ["LocalFile"] - - -class LocalFile(File): - def _open(self, mode: str) -> IO: - return io.open(self.path, mode=mode) diff --git a/blocksync/files/sftp_file.py b/blocksync/files/sftp_file.py deleted file mode 100644 index 11e8a4e..0000000 --- a/blocksync/files/sftp_file.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations 
- -import stat -from pathlib import Path -from typing import Union - -import paramiko - -from blocksync.files.interfaces import File - -__all__ = ["SFTPFile"] - - -class SFTPFile(File): - def __init__(self, path: Union[Path, str], **ssh_options): - super().__init__(path) - ssh_client = paramiko.SSHClient() - ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy) - ssh_client.load_system_host_keys() - ssh_client.connect(**ssh_options) - self._ssh: paramiko.SSHClient = ssh_client - - def do_close(self, flush=True) -> SFTPFile: - if self.ssh_connected: - super().do_close(flush) - return self - - def _open(self, mode: str) -> paramiko.SFTPFile: - if self.ssh_connected and (sftp := self._ssh.open_sftp()): - return sftp.open(self.path if isinstance(self.path, str) else str(self.path), mode=mode) - raise ValueError("Cannot open the remote file. Please connect paramiko.SSHClient") - - def _get_size(self, fileobj: paramiko.SFTPFile) -> int: # type: ignore[override] - size = super()._get_size(fileobj) - if size == 0 and stat.S_ISBLK(fileobj.stat().st_mode): # type: ignore[arg-type] - stdin, stdout, stderr = self._ssh.exec_command( - f"""python -c "with open('{self.path}', 'r') as f: f.seek(0, 2); print(f.tell())" """, - ) - return int(stdout.read()) - return size - - @property - def ssh_connected(self) -> bool: - if transport := self._ssh.get_transport(): - return transport.is_active() - return False diff --git a/blocksync/hooks.py b/blocksync/hooks.py deleted file mode 100644 index 79687d8..0000000 --- a/blocksync/hooks.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Any, Callable, Optional - -from blocksync.status import Status - -__all__ = ["Hooks"] - - -class Hooks: - def __init__(self): - self.root_before: Optional[Callable[[None], Any]] = None - self.before: Optional[Callable[[None], Any]] = None - self.root_after: Optional[Callable[[Status], Any]] = None - self.after: Optional[Callable[[Status], Any]] = None - self.monitor: Optional[Callable[[Status], Any]] = None - self.on_error: Optional[Callable[[Status], Any]] = None - - def _run(self, hook: Optional[Callable], *args, **kwargs): - if hook: - hook(*args, **kwargs) - - def run_root_before(self): - self._run(self.root_before) - - def run_before(self): - self._run(self.before) - - def run_root_after(self, status: Status): - self._run(self.root_after, status) - - def run_after(self, status: Status): - self._run(self.after, status) - - def run_monitor(self, status: Status): - self._run(self.monitor, status) - - def run_on_error(self, exc: Exception, status: Status): - self._run(self.on_error, exc, status) diff --git a/blocksync/status.py b/blocksync/status.py deleted file mode 100644 index 39c008e..0000000 --- a/blocksync/status.py +++ /dev/null @@ -1,50 +0,0 @@ -import threading -from typing import Literal, TypedDict - -from blocksync.consts import ByteSizes - - -class Blocks(TypedDict): - same: int - diff: int - done: int - - -class Status: - def __init__(self): - self._lock = threading.Lock() - self.block_size: int = ByteSizes.MiB - self._source_size: int = 0 - self._destination_size: int = 0 - self._blocks: Blocks = Blocks(same=0, diff=0, done=0) - - def initialize(self, /, block_size: int = ByteSizes.MiB, source_size: int = 0, destination_size: int = 0): - self.block_size = block_size - self._source_size = source_size - self._destination_size = destination_size - self._blocks = Blocks(same=0, diff=0, done=0) - - def _add_block(self, block_type: Literal["same", "diff"]): - with self._lock: - self._blocks[block_type] += 1 - 
self._blocks["done"] = self._blocks["same"] + self._blocks["diff"] - - @property - def source_size(self) -> int: - return self._source_size - - @property - def destination_size(self) -> int: - return self._destination_size - - @property - def blocks(self) -> Blocks: - return self._blocks - - @property - def rate(self) -> float: - return ( - min(100.00, (self._blocks["done"] / (self._source_size // self.block_size)) * 100) - if self._blocks["done"] > 1 - else 0.00 - ) diff --git a/blocksync/sync.py b/blocksync/sync.py new file mode 100644 index 0000000..731e8f8 --- /dev/null +++ b/blocksync/sync.py @@ -0,0 +1,444 @@ +import hashlib +import io +import logging +import threading +import time +import timeit +from math import ceil +from typing import IO, Any, Callable, Dict, Generator, Optional, Tuple, Union + +import paramiko + +from blocksync._consts import BASE_DIR, DIFF, SKIP, ByteSizes +from blocksync._hooks import Hooks +from blocksync._status import Status +from blocksync._sync_manager import SyncManager + +__all__ = ["local_to_local", "local_to_remote", "remote_to_local"] + +READ_SERVER_SCRIPT_NAME = "_read_server.py" +DEFAULT_READ_SERVER_SCRIPT_PATH = str((BASE_DIR / READ_SERVER_SCRIPT_NAME).resolve()) +WRITE_SERVER_SCRIPT_NAME = "_write_server.py" +DEFAULT_WRITE_SERVER_SCRIPT_PATH = str((BASE_DIR / WRITE_SERVER_SCRIPT_NAME).resolve()) + +logger = logging.getLogger("blocksync") +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + + +def _get_block_size(block_size: Union[int, str]) -> int: + if isinstance(block_size, str): + return ByteSizes.parse_readable_byte_size(block_size) + return block_size + + +def _get_range(worker_id: int, status: Status) -> Tuple[int, int]: + start = status.chunk_size * (worker_id - 1) + chunk_size = status.chunk_size + if worker_id == status.workers: + chunk_size += status.src_size % status.workers + return start, ceil(chunk_size / status.block_size) + + +def _get_size(path: str) -> int: + fileobj = open(path, "r") + fileobj.seek(io.SEEK_SET, io.SEEK_END) + size: int = fileobj.tell() + fileobj.seek(io.SEEK_SET) + return size + + +def _get_remotedev_size(ssh: paramiko.SSHClient, command: str, path: str) -> int: + stdin, stdout, _ = ssh.exec_command(command) + try: + stdin.write(f"{path}\n") + return int(stdout.readline()) + finally: + stdout.close() + stdin.close() + + +def _do_create(path: str, size: int): + with open(path, "a+") as fileobj: + fileobj.truncate(size) + + +def _get_blocks(fileobj: IO, block_size: int) -> Generator[bytes, None, None]: + while block := fileobj.read(block_size): + yield block + + +def _log(worker_id: int, msg: str, level: int = logging.INFO, *args, **kwargs): + logger.log(level, f"[Worker {worker_id}]: {msg}", *args, **kwargs) + + +def _connect_ssh( + allow_load_system_host_keys: bool = True, + compress: bool = True, + **ssh_config, +) -> paramiko.SSHClient: + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy) + if allow_load_system_host_keys: + ssh.load_system_host_keys() + ssh.connect(**ssh_config, compress=compress) + return ssh + + +def _sync( + manager: SyncManager, + status: Status, + workers: int, + sync: Callable, + sync_options: Dict[str, Any], + wait: bool = False, +) -> Tuple[Optional[SyncManager], Status]: + for i in range(1, workers + 1): + sync_options["worker_id"] = i + worker = threading.Thread(target=sync, kwargs=sync_options) + worker.start() + manager.workers.append(worker) + if wait: + manager.wait_sync() + return None, status + return manager, status + 
+
+def local_to_local(
+    src: str,
+    dest: str,
+    block_size: Union[str, int] = ByteSizes.MiB,
+    workers: int = 1,
+    create_dest: bool = False,
+    wait: bool = False,
+    dryrun: bool = False,
+    on_before: Optional[Callable[..., Any]] = None,
+    on_after: Optional[Callable[[Status], Any]] = None,
+    monitor: Optional[Callable[[Status], Any]] = None,
+    on_error: Optional[Callable[[Exception, Status], Any]] = None,
+    monitoring_interval: Union[int, float] = 1,
+    sync_interval: Union[int, float] = 0,
+) -> Tuple[Optional[SyncManager], Status]:
+    status = Status(
+        workers=workers,
+        block_size=_get_block_size(block_size),
+        src_size=_get_size(src),
+    )
+    if create_dest:
+        _do_create(dest, status.src_size)
+    status.dest_size = _get_size(dest)
+    manager = SyncManager()
+    sync_options = {
+        "src": src,
+        "dest": dest,
+        "status": status,
+        "manager": manager,
+        "hooks": Hooks(on_before=on_before, on_after=on_after, monitor=monitor, on_error=on_error),
+        "dryrun": dryrun,
+        "monitoring_interval": monitoring_interval,
+        "sync_interval": sync_interval,
+    }
+    return _sync(manager, status, workers, _local_to_local, sync_options, wait)
+
+
+def _local_to_local(
+    worker_id: int,
+    src: str,
+    dest: str,
+    status: Status,
+    manager: SyncManager,
+    hooks: Hooks,
+    dryrun: bool,
+    monitoring_interval: Union[int, float],
+    sync_interval: Union[int, float],
+):
+    hooks.run_before()
+
+    startpos, maxblock = _get_range(worker_id, status)
+    _log(worker_id, f"Start sync({src} -> {dest}) {maxblock} blocks")
+    srcdev = io.open(src, "rb+")
+    destdev = io.open(dest, "rb+")
+    srcdev.seek(startpos)
+    destdev.seek(startpos)
+
+    t_last = timeit.default_timer()
+    try:
+        for src_block, dest_block, _ in zip(
+            _get_blocks(srcdev, status.block_size),
+            _get_blocks(destdev, status.block_size),
+            range(maxblock),
+        ):
+            if manager.suspended:
+                _log(worker_id, "Wait resuming...")
+                manager._wait_resuming()
+            if manager.canceled:
+                break
+
+            if src_block != dest_block:
+                if not dryrun:
+                    destdev.seek(-len(src_block), io.SEEK_CUR)
+                    destdev.write(src_block)
+                    destdev.flush()
+                status.add_block("diff")
+            else:
+                status.add_block("same")
+
+            t_cur = timeit.default_timer()
+            if monitoring_interval <= t_cur - t_last:
+                hooks.run_monitor(status)
+                t_last = t_cur
+            if 0 < sync_interval:
+                time.sleep(sync_interval)
+    except Exception as e:
+        _log(worker_id, msg=str(e), exc_info=True)
+        hooks.run_on_error(e, status)
+    finally:
+        srcdev.close()
+        destdev.close()
+        hooks.run_after(status)
+
+
+def local_to_remote(
+    src: str,
+    dest: str,
+    block_size: Union[str, int] = ByteSizes.MiB,
+    workers: int = 1,
+    create_dest: bool = False,
+    wait: bool = False,
+    dryrun: bool = False,
+    on_before: Optional[Callable[..., Any]] = None,
+    on_after: Optional[Callable[[Status], Any]] = None,
+    monitor: Optional[Callable[[Status], Any]] = None,
+    on_error: Optional[Callable[[Exception, Status], Any]] = None,
+    monitoring_interval: Union[int, float] = 1,
+    sync_interval: Union[int, float] = 0,
+    hash1: str = "sha256",
+    read_server_command: Optional[str] = None,
+    write_server_command: Optional[str] = None,
+    allow_load_system_host_keys: bool = True,
+    compress: bool = True,
+    **ssh_config,
+) -> Tuple[Optional[SyncManager], Status]:
+    status: Status = Status(
+        workers=workers,
+        block_size=_get_block_size(block_size),
+        src_size=_get_size(src),
+    )
+
+    ssh = _connect_ssh(allow_load_system_host_keys, compress, **ssh_config)
+    if sftp := ssh.open_sftp():
+        if read_server_command is None:
+            sftp.put(DEFAULT_READ_SERVER_SCRIPT_PATH, READ_SERVER_SCRIPT_NAME)
+            read_server_command = f"python3 {READ_SERVER_SCRIPT_NAME}"
+        if write_server_command is None:
+            sftp.put(DEFAULT_WRITE_SERVER_SCRIPT_PATH, WRITE_SERVER_SCRIPT_NAME)
+            write_server_command = f"python3 {WRITE_SERVER_SCRIPT_NAME}"
+
+    manager = SyncManager()
+    sync_options = {
+        "ssh": ssh,
+        "src": src,
+        "dest": dest,
+        "status": status,
+        "manager": manager,
+        "create_dest": create_dest,
+        "dryrun": dryrun,
+        "hooks": Hooks(on_before=on_before, on_after=on_after, monitor=monitor, on_error=on_error),
+        "monitoring_interval": monitoring_interval,
+        "sync_interval": sync_interval,
+        "hash1": hash1,
+        "read_server_command": read_server_command,
+        "write_server_command": write_server_command,
+    }
+    return _sync(manager, status, workers, _local_to_remote, sync_options, wait)
+
+
+def _local_to_remote(
+    worker_id: int,
+    ssh: paramiko.SSHClient,
+    src: str,
+    dest: str,
+    status: Status,
+    manager: SyncManager,
+    create_dest: bool,
+    dryrun: bool,
+    hooks: Hooks,
+    monitoring_interval: Union[int, float],
+    sync_interval: Union[int, float],
+    hash1: str,
+    read_server_command: str,
+    write_server_command: str,
+):
+    hash_ = getattr(hashlib, hash1)
+    hash_len = hash_().digest_size
+
+    hooks.run_before()
+
+    reader_stdin, reader_stdout, _ = ssh.exec_command(read_server_command)
+    writer_stdin, writer_stdout, _ = ssh.exec_command(write_server_command)
+    writer_stdin.write(f"{dest}\n{status.src_size if create_dest else 0}\n")
+    reader_stdin.write(f"{dest}\n")
+    status.dest_size = int(reader_stdout.readline())
+    startpos, maxblock = _get_range(worker_id, status)
+    _log(worker_id, f"Start sync({src} -> {dest}) {maxblock} blocks")
+    reader_stdin.write(f"{status.block_size}\n{hash1}\n{startpos}\n{maxblock}\n")
+    writer_stdin.write(f"{status.block_size}\n{startpos}\n{maxblock}\n")
+
+    t_last = timeit.default_timer()
+    with open(src, "rb+") as fileobj:
+        fileobj.seek(startpos)
+        try:
+            for src_block, _ in zip(_get_blocks(fileobj, status.block_size), range(maxblock)):
+                if manager.suspended:
+                    _log(worker_id, "Wait resuming...")
+                    manager._wait_resuming()
+                if manager.canceled:
+                    break
+
+                src_block_hash: bytes = hash_(src_block).digest()
+                dest_block_hash: bytes = reader_stdout.read(hash_len)
+                reader_stdin.write(SKIP)
+                if src_block_hash != dest_block_hash:
+                    if not dryrun:
+                        writer_stdin.write(DIFF)
+                        writer_stdin.write(src_block)
+                    else:
+                        writer_stdin.write(SKIP)
+                    status.add_block("diff")
+                else:
+                    writer_stdin.write(SKIP)  # the write server expects a control byte for every block
+                    status.add_block("same")
+
+                t_cur = timeit.default_timer()
+                if monitoring_interval <= t_cur - t_last:
+                    hooks.run_monitor(status)
+                    t_last = t_cur
+                if 0 < sync_interval:
+                    time.sleep(sync_interval)
+        except Exception as e:
+            _log(worker_id, msg=str(e), exc_info=True)
+            hooks.run_on_error(e, status)
+        finally:
+            reader_stdin.close()
+            reader_stdout.close()
+            writer_stdin.close()
+            writer_stdout.close()
+            hooks.run_after(status)
+
+
+def remote_to_local(
+    src: str,
+    dest: str,
+    block_size: Union[str, int] = ByteSizes.MiB,
+    workers: int = 1,
+    create_dest: bool = False,
+    wait: bool = False,
+    dryrun: bool = False,
+    on_before: Optional[Callable[..., Any]] = None,
+    on_after: Optional[Callable[[Status], Any]] = None,
+    monitor: Optional[Callable[[Status], Any]] = None,
+    on_error: Optional[Callable[[Exception, Status], Any]] = None,
+    monitoring_interval: Union[int, float] = 1,
+    sync_interval: Union[int, float] = 0,
+    hash1: str = "sha256",
+    allow_load_system_host_keys: bool = True,
+    compress: bool = True,
+    read_server_command: Optional[str] = None,
+    **ssh_config,
+):
+    ssh =
_connect_ssh(allow_load_system_host_keys, compress, **ssh_config) + if read_server_command is None and (sftp := ssh.open_sftp()): + sftp.put(DEFAULT_READ_SERVER_SCRIPT_PATH, READ_SERVER_SCRIPT_NAME) + read_server_command = f"python3 {READ_SERVER_SCRIPT_NAME}" + + status = Status( + workers=workers, + block_size=ByteSizes.parse_readable_byte_size(block_size) if isinstance(block_size, str) else block_size, + src_size=_get_remotedev_size(ssh, read_server_command, src), # type: ignore[arg-type] + ) + if create_dest: + _do_create(dest, status.src_size) + status.dest_size = _get_size(dest) + manager = SyncManager() + sync_options = { + "ssh": ssh, + "src": src, + "dest": dest, + "status": status, + "manager": manager, + "dryrun": dryrun, + "hooks": Hooks(on_before=on_before, on_after=on_after, monitor=monitor, on_error=on_error), + "monitoring_interval": monitoring_interval, + "sync_interval": sync_interval, + "hash1": hash1, + "read_server_command": read_server_command, + } + return _sync(manager, status, workers, _remote_to_local, sync_options, wait) + + +def _remote_to_local( + worker_id: int, + ssh: paramiko.SSHClient, + src: str, + dest: str, + status: Status, + manager: SyncManager, + dryrun: bool, + monitoring_interval: Union[int, float], + sync_interval: Union[int, float], + hash1: str, + read_server_command: str, + hooks: Hooks, +): + hash_ = getattr(hashlib, hash1) + hash_len = hash_().digest_size + + hooks.run_before() + + reader_stdin, _, _ = ssh.exec_command(read_server_command) + reader_stdout = reader_stdin.channel.makefile("rb") + reader_stdin.write(f"{src}\n") + reader_stdout.readline() + startpos, maxblock = _get_range(worker_id, status) + _log(worker_id, f"Start sync({src} -> {dest}) {maxblock} blocks") + reader_stdin.write(f"{status.block_size}\n{hash1}\n{startpos}\n{maxblock}\n") + + t_last = timeit.default_timer() + with open(dest, "rb+") as fileobj: + fileobj.seek(startpos) + try: + for dest_block, _ in zip(_get_blocks(fileobj, status.block_size), range(maxblock)): + if manager.suspended: + _log(worker_id, "Wait resuming...") + manager._wait_resuming() + if manager.canceled: + break + + src_block_hash: bytes = reader_stdout.read(hash_len) + dest_block_hash: bytes = hash_(dest_block).digest() + if src_block_hash != dest_block_hash: + if not dryrun: + reader_stdin.write(DIFF) + src_block = reader_stdout.read(status.block_size) + fileobj.seek(-len(src_block), io.SEEK_CUR) + fileobj.write(src_block) + fileobj.flush() + else: + reader_stdin.write(SKIP) + status.add_block("diff") + else: + reader_stdin.write(SKIP) + status.add_block("same") + + t_cur = timeit.default_timer() + if monitoring_interval <= t_cur - t_last: + hooks.run_monitor(status) + t_last = t_cur + if 0 < sync_interval: + time.sleep(sync_interval) + except Exception as e: + _log(worker_id, msg=str(e), exc_info=True) + hooks.run_on_error(e, status) + finally: + reader_stdin.close() + reader_stdout.close() + hooks.run_after(status) diff --git a/blocksync/syncer.py b/blocksync/syncer.py deleted file mode 100644 index e1225fe..0000000 --- a/blocksync/syncer.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -import logging -import logging.handlers -import os -import threading -from typing import List, Tuple, Union - -from blocksync.consts import ByteSizes -from blocksync.files.interfaces import File -from blocksync.hooks import Hooks -from blocksync.status import Status -from blocksync.worker import Worker - -__all__ = ["Syncer"] - -blocksync_logger = logging.getLogger(__name__) 
-blocksync_logger.setLevel(logging.INFO) -blocksync_logger.addHandler(logging.StreamHandler()) - - -class Syncer: - def __init__(self, src: File, dest: File) -> None: - self.src = src - self.dest = dest - - self.status: Status = Status() - self.hooks: Hooks = Hooks() - - self.logger: logging.Logger = blocksync_logger - - self._started: bool = False - self._canceled: bool = False - self._suspended: threading.Event = threading.Event() - self._workers: List[threading.Thread] = [] - - def __repr__(self): - return f"" - - def start_sync( - self, - workers: int = 1, - block_size: int = ByteSizes.MiB, - wait: bool = False, - dryrun: bool = False, - create: bool = False, - sync_interval: Union[float, int] = 0.1, - monitoring_interval: Union[float, int] = 1, - ) -> Syncer: - if workers < 1: - raise ValueError("Workers must be greater than 1") - self._canceled = False - self._suspended.set() - self._pre_sync(create) - self.status.initialize(block_size=block_size, source_size=self.src.size, destination_size=self.dest.size) - self._workers = [] - for i in range(1, workers + 1): - startpos, endpos = self._get_positions(workers, i) - worker = Worker( - worker_id=i, - syncer=self, - startpos=startpos, - endpos=endpos, - dryrun=dryrun, - sync_interval=sync_interval, - monitoring_interval=monitoring_interval, - logger=self.logger, - ) - worker.start() - self._workers.append(worker) - self._started = True - if wait: - self.wait() - return self - - def _pre_sync(self, create: bool = False): - self.src.do_open() - try: - self.dest.do_open() - except FileNotFoundError: - if not create: - raise - self.dest.do_create(self.src.size).do_open() - if self.src.size > self.dest.size: - self.logger.warning(f"Source size({self.src.size}) is greater than destination size({self.dest.size})") - elif self.src.size < self.dest.size: - self.logger.info(f"Source size({self.src.size}) is less than destination size({self.dest.size})") - - def _get_positions(self, workers: int, worker_id: int) -> Tuple[int, int]: - chunk_size = self.src.size // workers - start = os.SEEK_SET - end = chunk_size * worker_id - if 1 < worker_id: - start = chunk_size * (worker_id - 1) - if worker_id == workers: - end += self.src.size % workers - return start, end - - def wait(self) -> Syncer: - for worker in self._workers: - if worker.is_alive(): - worker.join() - return self - - def cancel(self) -> Syncer: - self._canceled = True - return self - - def suspend(self) -> Syncer: - if self._suspended.is_set(): - self._suspended.clear() - return self - - def resume(self) -> Syncer: - if not self._suspended.is_set(): - self._suspended.set() - return self - - @property - def workers(self) -> List[threading.Thread]: - return self._workers - - @property - def suspended(self) -> bool: - return not self._suspended.is_set() - - @property - def canceled(self) -> bool: - return self._canceled - - @property - def started(self) -> bool: - return self._started - - @property - def finished(self) -> bool: - if not self._started: - return False - for worker in self._workers: - if worker.is_alive(): - return False - return True diff --git a/blocksync/worker.py b/blocksync/worker.py deleted file mode 100644 index b9c1a60..0000000 --- a/blocksync/worker.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import logging -import os -import threading -import time -from timeit import default_timer as timer -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from blocksync import Syncer - -__all__ = ["Worker"] - - -class Worker(threading.Thread): 
- def __init__( - self, - worker_id: int, - syncer: Syncer, - startpos: int, - endpos: int, - dryrun: bool, - sync_interval: Union[int, float], - monitoring_interval: Union[int, float], - logger: logging.Logger, - ): - super().__init__() - self.worker_id = worker_id - self.syncer: Syncer = syncer - self.startpos: int = startpos - self.endpos: int = endpos - self.dryrun: bool = dryrun - self.sync_interval = sync_interval - self.monitoring_interval = monitoring_interval - self.logger: logging.Logger = logger - - def run(self): - self.syncer.hooks.run_root_before() - try: - self._sync() - self.syncer.hooks.run_root_after(self.syncer.status) - finally: - self.syncer.src.do_close() - self.syncer.dest.do_close() - - def _sync(self): - self.syncer.src.do_open().io.seek(self.startpos) # type: ignore[union-attr] - self.syncer.dest.do_open().io.seek(self.startpos) # type: ignore[union-attr] - - self._log( - f"Start sync(startpos: {self.startpos}, endpos: {self.endpos}) {self.syncer.src} to {self.syncer.dest}" - ) - self.syncer.hooks.run_before() - - t_last = timer() - try: - for source_block, dest_block in zip( - self.syncer.src.get_blocks(self.syncer.status.block_size), - self.syncer.dest.get_blocks(self.syncer.status.block_size), - ): - if self.syncer.suspended: - self._log("Suspending...") - self.syncer._suspended.wait() - if self.syncer.canceled: - self._log("Synchronization task has been canceled") - return - if source_block == dest_block: - self.syncer.status._add_block("same") - else: - self.syncer.status._add_block("diff") - if not self.dryrun: - offset = min(len(source_block), len(dest_block), self.syncer.status.block_size) - self.syncer.dest.io.seek(-offset, os.SEEK_CUR) # type: ignore[union-attr] - self.syncer.dest.io.write(source_block) # type: ignore[union-attr] - self.syncer.dest.io.flush() # type: ignore[union-attr] - t_cur = timer() - if self.monitoring_interval <= t_cur - t_last: - self.syncer.hooks.run_monitor(self.syncer.status) - t_last = t_cur - if self.endpos <= self.syncer.src.io.tell(): # type: ignore[union-attr] - self._log("!!! 
Done !!!") - break - if 0 < self.sync_interval: - time.sleep(self.sync_interval) - self.syncer.hooks.run_after(self.syncer.status) - except Exception as e: - self._log(str(e), level=logging.ERROR, exc_info=True) - self.syncer.hooks.run_on_error(e, self.syncer.status) - - def _log(self, msg: str, level: int = logging.INFO, *args, **kwargs): - self.logger.log(level, f"[Worker {self.worker_id}]: {msg}", *args, **kwargs) diff --git a/setup.cfg b/setup.cfg index dd3d882..5e42705 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,6 +68,8 @@ omit = venv/* tests/* setup.py + */_read_server.py + */_write_server.py [mypy] python_version = 3.8 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..77705ac --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,25 @@ +import pytest + +from blocksync._status import Status + +pytest_plugins = "pytester" + + +@pytest.fixture +def fake_status(): + return Status( + workers=2, + block_size=500, + src_size=1_000, + dest_size=1_000, + ) + + +@pytest.fixture(scope="session") +def source_content(): + return b"source content" + + +@pytest.fixture +def source_file(pytester, source_content): + return pytester.makefile(".img", source_content) diff --git a/tests/test_consts.py b/tests/test_consts.py new file mode 100644 index 0000000..c6cc85c --- /dev/null +++ b/tests/test_consts.py @@ -0,0 +1,8 @@ +from blocksync._consts import ByteSizes + + +def test_parse_readable_byte_size(): + assert ByteSizes.MiB == ByteSizes.parse_readable_byte_size("1_048_576") + + assert ByteSizes.MiB == ByteSizes.parse_readable_byte_size("1M") + assert ByteSizes.M == ByteSizes.parse_readable_byte_size("1M") diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 71c9cf7..756abb3 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -2,53 +2,30 @@ import pytest -from blocksync.hooks import Hooks -from blocksync.status import Status +from blocksync._hooks import Hooks @pytest.fixture -def stub_status(): - return Status() +def stub_hooks(): + return Hooks(Mock(), Mock(), Mock(), Mock()) -def test_run_root_before(): - hook = Hooks() - hook.root_before = Mock() - hook.run_root_before() - hook.root_before.assert_called_once() +def test_run_before(stub_hooks): + stub_hooks.run_before() + stub_hooks.before.assert_called_once() -def test_run_before(): - hook = Hooks() - hook.before = Mock() - hook.run_before() - hook.before.assert_called_once() +def test_run_after(stub_hooks, fake_status): + stub_hooks.run_after(fake_status) + stub_hooks.after.assert_called_once_with(fake_status) -def test_run_root_after(stub_status): - hook = Hooks() - hook.root_after = Mock() - hook.run_root_after(stub_status) - hook.root_after.assert_called_once_with(stub_status) +def test_run_monitor(stub_hooks, fake_status): + stub_hooks.run_monitor(fake_status) + stub_hooks.monitor.assert_called_once_with(fake_status) -def test_run_after(stub_status): - hook = Hooks() - hook.after = Mock() - hook.run_after(stub_status) - hook.after.assert_called_once_with(stub_status) - - -def test_run_monitor(stub_status): - hook = Hooks() - hook.monitor = Mock() - hook.run_monitor(stub_status) - hook.monitor.assert_called_once_with(stub_status) - - -def test_run_on_error(stub_status): - hook = Hooks() - hook.on_error = Mock() +def test_run_on_error(stub_hooks, fake_status): exc = Exception() - hook.run_on_error(exc, stub_status) - hook.on_error.assert_called_once_with(exc, stub_status) + stub_hooks.run_on_error(exc, fake_status) + stub_hooks.on_error.assert_called_once_with(exc, fake_status) diff --git 
a/tests/test_local_file.py b/tests/test_local_file.py deleted file mode 100644 index 067eaec..0000000 --- a/tests/test_local_file.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import unittest.mock -from pathlib import Path - -import pytest - -from blocksync import LocalFile - - -@pytest.fixture -def stub_local_file(tmp_path): - path = tmp_path / "local.file" - file = LocalFile(path) - return file.do_create(10) - - -def get_file_size(path: Path): - file = open(path, "r") - file.seek(os.SEEK_SET, os.SEEK_END) - return file.tell() - - -def test_do_create(tmp_path): - # Expect: Create a file of a specific size - path = tmp_path / "local.file" - file = LocalFile(path) - file.do_create(10) - assert path.exists() - assert get_file_size(path) == 10 - - -def test_do_open(stub_local_file): - # When: Create file and open it - stub_local_file.do_open() - - # Then: File opened for reading and writing in binary mode - assert stub_local_file.io.mode == "rb+" - - # And: File positioning is file's start - assert stub_local_file.io.tell() == 0 - - # And: Set file.size to actual file size - assert stub_local_file.size == get_file_size(stub_local_file.path) - - -def test_do_close(stub_local_file): - # Expect: File closed - stub_local_file.do_open().do_close() - assert stub_local_file.io.closed - with pytest.raises(ValueError, match="closed file"): - stub_local_file.io.write(b"test") - - # Expect: Don't flush - io = stub_local_file.io - io.flush = unittest.mock.Mock() - stub_local_file.do_open().do_close(flush=False) - io.flush.assert_not_called() - - -def test_get_block(stub_local_file): - # Expect: Read file - stub_local_file.do_open().io.write(b"1234567890") - stub_local_file.io.seek(0) - assert stub_local_file.get_block() == b"1234567890" - - # Expect: Read specific size from file's current position - stub_local_file.io.seek(0) - assert stub_local_file.get_block(5) == b"12345" - - -def test_get_blocks(stub_local_file): - # Expect: Read file blocks separated by specific size - stub_local_file.do_open().io.write(b"1234567890") - stub_local_file.io.seek(0) - assert list(stub_local_file.get_blocks(5)) == [b"12345", b"67890"] - - -def test_io_property(stub_local_file): - # Expect: Set io after file open - assert stub_local_file.do_open().io is not None - - # Expect: io is None when file not opened - assert LocalFile("a.file").io is None - - -def test_opened_property(stub_local_file): - # Expect: Return True when file opened - stub_local_file.do_open() - assert stub_local_file.opened - - # Expect: Return False when file closed - stub_local_file.do_close() - assert not stub_local_file.opened - - # Expect: Return False when file not opened - assert not LocalFile("a.file").opened diff --git a/tests/test_read_server.py b/tests/test_read_server.py new file mode 100644 index 0000000..c460168 --- /dev/null +++ b/tests/test_read_server.py @@ -0,0 +1,24 @@ +import subprocess +from hashlib import sha256 + +from blocksync._consts import BASE_DIR + + +def test_read_server(source_file, source_content, pytester): + p = pytester.popen( + ["python", (BASE_DIR / "_read_server.py")], + bufsize=0, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdin, stdout = p.stdin, p.stdout + stdin.write(f"{source_file}\n".encode()) + assert int(stdout.readline()) == len(source_content) + + stdin.write(f"{len(source_content)}\nsha256\n0\n1\n".encode()) + hashed = sha256(source_content) + digest = stdout.read(hashed.digest_size) + assert digest == hashed.digest() + + stdin.write(b"2") + assert stdout.read(len(source_content)) == 
source_content diff --git a/tests/test_sftp_file.py b/tests/test_sftp_file.py deleted file mode 100644 index 7532f3e..0000000 --- a/tests/test_sftp_file.py +++ /dev/null @@ -1,91 +0,0 @@ -from pathlib import Path -from unittest.mock import Mock - -import paramiko -import pytest - -from blocksync import SFTPFile - - -@pytest.fixture(autouse=True) -def stub_ssh_client(mocker): - ssh_client = mocker.patch( - "blocksync.files.sftp_file.paramiko.SSHClient", - Mock(get_transport=Mock(return_value=Mock(is_active=Mock(return_value=True)))), - ) - return ssh_client.return_value - - -@pytest.fixture -def stub_sftp_client(stub_ssh_client): - return stub_ssh_client.open_sftp.return_value - - -def test_target_file_is_a_block_device(mocker, stub_ssh_client, stub_sftp_client): - # When: The target file is a block device - stub_ssh_client.exec_command.return_value = (Mock(), Mock(read=Mock(return_value=b"10")), Mock()) - mocker.patch("blocksync.files.sftp_file.File._get_size", return_value=0) - stub_sftp_client.open.return_value.stat.return_value.st_mode = 25008 - file = SFTPFile("filepath") - - # Then: Use a special command to get the size of the block device - assert file.do_open().size == 10 - stub_ssh_client.exec_command.assert_called_once_with( - """python -c "with open('filepath', 'r') as f: f.seek(0, 2); print(f.tell())" """ - ) - - # And: Return if already got the size - mocker.patch("blocksync.files.sftp_file.File._get_size", return_value=1) - assert file.do_open().size == 1 - - -def test_do_close(mocker, stub_ssh_client): - # Expect: Close only when ssh connected - mock_do_close = mocker.patch("blocksync.files.sftp_file.File.do_close") - file = SFTPFile("") - file.do_close() - stub_ssh_client.get_transport.return_value.is_active.return_value = False - file.do_close() - mock_do_close.assert_called_once_with(True) - - -def test_open(stub_ssh_client): - # Expect: Open without error even if path argument is an instance of pathlib.Path - file = SFTPFile(Path("test")) - file._open("r") - stub_ssh_client.open_sftp.return_value.open.assert_called_once_with("test", mode="r") - - -def test_raise_error_when_ssh_not_connected(stub_ssh_client): - # Expect: Raise error when ssh not connected - file = SFTPFile("") - stub_ssh_client.get_transport.return_value.is_active.return_value = False - with pytest.raises(ValueError, match="Cannot open the remote file. 
Please connect paramiko.SSHClient"): - file._open("r") - - -def test_call_setup_methods(stub_ssh_client): - # Expect: Call paramiko setup methods - SFTPFile("") - stub_ssh_client.set_missing_host_key_policy.assert_called_once_with(paramiko.AutoAddPolicy) - stub_ssh_client.load_system_host_keys.assert_called_once() - - -def test_pass_ssh_config(stub_ssh_client): - # Expect: Connect SSH using passed arguments - SFTPFile("", hostname="test", password="test") - stub_ssh_client.connect.assert_called_once_with(hostname="test", password="test") - - -def test_ssh_connected(stub_ssh_client): - # Expect: Return True when SSH connected - file = SFTPFile("") - assert file.ssh_connected - - # Expect: Return False when SSH not connected - stub_ssh_client.get_transport.return_value.is_active.return_value = False - assert not file.ssh_connected - - # Expect: Return False when SSH hasn't transport - stub_ssh_client.get_transport.return_value = None - assert not file.ssh_connected diff --git a/tests/test_status.py b/tests/test_status.py index bd20ff5..9da1ed6 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -1,56 +1,43 @@ -from blocksync.consts import ByteSizes -from blocksync.status import Blocks, Status - - -def test_initialize_status(): - # Expect: Initialize all status - status = Status() - status.initialize( - block_size=1_000, - source_size=1_000, - destination_size=1_000, - ) - status._add_block("same") - status._add_block("diff") - status.initialize(block_size=10, source_size=10, destination_size=10) - assert status.block_size == 10 - assert status.source_size == 10 - assert status.destination_size == 10 - assert status.blocks == Blocks(same=0, diff=0, done=0) - - -def test_add_block(): +from blocksync._consts import ByteSizes +from blocksync._status import Blocks + + +def test_initialize_status(fake_status): + # Expect: Set chunk size + assert fake_status.chunk_size == fake_status.src_size // fake_status.workers + + +def test_add_block(fake_status): # Expect: Add each blocks and calculate done block - status = Status() - status._add_block("same") - status._add_block("same") - status._add_block("diff") - assert status.blocks == Blocks(same=2, diff=1, done=3) + fake_status.add_block("same") + fake_status.add_block("same") + fake_status.add_block("diff") + assert fake_status.blocks == Blocks(same=2, diff=1, done=3) -def test_get_rate(): +def test_get_rate(fake_status): # Expect: Return 0.00 when nothing done - status = Status() - assert status.rate == 0.00 + assert fake_status.rate == 0.00 - status.initialize(source_size=ByteSizes.MiB * 10, destination_size=ByteSizes.MiB * 10) + fake_status.block_size = ByteSizes.MiB + fake_status.src_size = fake_status.dest_size = ByteSizes.MiB * 10 # Expect: Return 50.00 when half done - status._add_block("same") - status._add_block("same") - status._add_block("same") - status._add_block("diff") - status._add_block("diff") - assert status.rate == 50.00 + fake_status.add_block("same") + fake_status.add_block("same") + fake_status.add_block("same") + fake_status.add_block("diff") + fake_status.add_block("diff") + assert fake_status.rate == 50.00 # Expect: Return 100.00 when all done - status._add_block("same") - status._add_block("same") - status._add_block("same") - status._add_block("diff") - status._add_block("diff") - assert status.rate == 100.00 + fake_status.add_block("same") + fake_status.add_block("same") + fake_status.add_block("same") + fake_status.add_block("diff") + fake_status.add_block("diff") + assert fake_status.rate == 100.00 # Expect: 
Return 100.00 when exceeding the total size
-    status._add_block("diff")
-    assert status.rate == 100.00
+    fake_status.add_block("diff")
+    assert fake_status.rate == 100.00
diff --git a/tests/test_sync.py b/tests/test_sync.py
new file mode 100644
index 0000000..07c5ddc
--- /dev/null
+++ b/tests/test_sync.py
@@ -0,0 +1,82 @@
+from unittest.mock import Mock
+
+import paramiko
+
+from blocksync.sync import (
+    _connect_ssh,
+    _do_create,
+    _get_block_size,
+    _get_blocks,
+    _get_range,
+    _get_remotedev_size,
+    _get_size,
+    _log,
+)
+
+
+def test_get_block_size():
+    assert _get_block_size(1) == 1
+    assert _get_block_size("1B") == 1
+
+
+def test_get_range(fake_status):
+    fake_status.src_size += 1
+    assert _get_range(1, fake_status) == (0, 1)
+    assert _get_range(2, fake_status) == (500, 2)
+
+
+def test_get_size(source_file, source_content):
+    assert _get_size(str(source_file)) == len(source_content)
+
+
+def test_remotedev_size():
+    stub_stdin = Mock()
+    stub_stdout = Mock(readline=Mock(return_value=10))
+    stub_ssh_client = Mock(exec_command=Mock(return_value=(stub_stdin, stub_stdout, Mock())))
+    assert 10 == _get_remotedev_size(stub_ssh_client, "command", "path")
+    stub_stdin.write.assert_called_once_with("path\n")
+    stub_stdout.readline.assert_called_once()
+    stub_stdin.close.assert_called_once()
+    stub_stdout.close.assert_called_once()
+
+
+def test_do_create(pytester):
+    path = pytester.path / "new.file"
+    _do_create(str(path), 10)
+    assert path.exists()
+    assert _get_size(str(path)) == 10
+
+
+def test_get_blocks(pytester):
+    path = pytester.path / "new.file"
+    _do_create(str(path), 20)
+    file_data = path.read_bytes()
+    with open(path, "rb") as f:
+        read_data = list(_get_blocks(f, 10))
+    assert read_data[0] == file_data[:10]
+    assert read_data[1] == file_data[10:]
+
+
+def test_log(mocker):
+    mock_logger = mocker.patch("blocksync.sync.logger")
+    _log(1, "test", 10)
+    mock_logger.log.assert_called_once_with(10, "[Worker 1]: test")
+
+
+def test_connect_ssh(mocker):
+    mock_ssh_client = mocker.patch("blocksync.sync.paramiko.SSHClient")()
+    _connect_ssh(
+        hostname="hostname",
+        password="password",
+    )
+    mock_ssh_client.set_missing_host_key_policy.assert_called_once_with(paramiko.AutoAddPolicy)
+    mock_ssh_client.load_system_host_keys.assert_called_once()
+    mock_ssh_client.connect.assert_called_once_with(
+        hostname="hostname",
+        password="password",
+        compress=True,
+    )
+
+    mock_ssh_client.reset_mock()
+    _connect_ssh(allow_load_system_host_keys=False)
+    mock_ssh_client.load_system_host_keys.assert_not_called()
diff --git a/tests/test_sync_manager.py b/tests/test_sync_manager.py
new file mode 100644
index 0000000..7c0e610
--- /dev/null
+++ b/tests/test_sync_manager.py
@@ -0,0 +1,45 @@
+from unittest.mock import Mock
+
+from blocksync._sync_manager import SyncManager
+
+
+def test_cancel_sync():
+    manager = SyncManager()
+    assert not manager.canceled
+
+    manager.cancel_sync()
+    assert manager.canceled
+
+
+def test_wait_sync():
+    mock_worker1, mock_worker2 = Mock(), Mock()
+    manager = SyncManager()
+    manager.workers.append(mock_worker1)
+    manager.workers.append(mock_worker2)
+    manager.wait_sync()
+    mock_worker1.join.assert_called_once()
+    mock_worker2.join.assert_called_once()
+
+
+def test_suspend_and_resume():
+    manager = SyncManager()
+    manager.suspend()
+    assert manager.suspended
+
+    manager.resume()
+    assert not manager.suspended
+
+    manager._suspend = Mock()
+    manager._wait_resuming()
+    manager._suspend.wait.assert_called_once()
+
+
+def test_finished():
+    worker = Mock()
+    worker.is_alive.return_value = False
+
manager = SyncManager() + manager.workers.append(worker) + assert manager.finished + + worker.is_alive.return_value = True + assert not manager.finished diff --git a/tests/test_syncer.py b/tests/test_syncer.py deleted file mode 100644 index 527d4c6..0000000 --- a/tests/test_syncer.py +++ /dev/null @@ -1,165 +0,0 @@ -from unittest.mock import Mock - -import pytest - -from blocksync.syncer import Syncer - - -@pytest.fixture(autouse=True) -def mock_worker(mocker): - return mocker.patch("blocksync.syncer.Worker") - - -@pytest.fixture -def mock_syncer(): - syncer = Syncer(Mock(), Mock()) - syncer.src.size = syncer.dest.size = 10 - syncer._suspended = Mock() - syncer.status = Mock() - syncer.hooks = Mock() - syncer.logger = Mock() - return syncer - - -def test_raise_error_if_worker_less_than_1(mock_syncer): - with pytest.raises(ValueError, match="Workers must be greater than 1"): - mock_syncer.start_sync(workers=0) - - -def test_pre_sync(mock_syncer): - # Expect: Open both of target files - mock_syncer.src.size = mock_syncer.dest.size = 10 - mock_syncer._pre_sync() - mock_syncer.src.do_open.assert_called_once() - mock_syncer.dest.do_open.assert_called_once() - - -def test_when_destination_file_does_not_exists(mock_syncer): - # Expect: Raise error if create is False - mock_syncer.dest.do_open.side_effect = FileNotFoundError - with pytest.raises(FileNotFoundError): - mock_syncer._pre_sync(create=False) - - # Expect: Create an empty file with the size of the source - mock_syncer._pre_sync(create=True) - mock_syncer.dest.do_create.assert_called_once_with(10) - mock_syncer.dest.do_create.return_value.do_open.assert_called_once() - - -def test_when_src_and_dest_file_size_does_not_same(mock_syncer): - # Expect: Call info with expected msg - mock_syncer.src.size = mock_syncer.dest.size - 1 - mock_syncer._pre_sync() - mock_syncer.logger.info.assert_called_once_with("Source size(9) is less than destination size(10)") - - # Expect: Call warning with expected msg - mock_syncer.src.size = mock_syncer.dest.size + 1 - mock_syncer._pre_sync() - mock_syncer.logger.warning.assert_called_once_with("Source size(11) is greater than destination size(10)") - - -def test_get_positions(mock_syncer): - # Expect: Return the correct positions - assert mock_syncer._get_positions(workers=1, worker_id=1) == (0, 10) - assert mock_syncer._get_positions(workers=2, worker_id=1) == (0, 5) - assert mock_syncer._get_positions(workers=2, worker_id=2) == (5, 10) - assert mock_syncer._get_positions(workers=3, worker_id=1) == (0, 3) - assert mock_syncer._get_positions(workers=3, worker_id=2) == (3, 6) - assert mock_syncer._get_positions(workers=3, worker_id=3) == (6, 10) - - -def test_start_sync(mock_syncer, mock_worker): - mock_syncer._pre_sync = Mock() - mock_syncer.wait = Mock() - mock_syncer.start_sync( - workers=1, - block_size=1, - wait=True, - dryrun=False, - create=False, - sync_interval=1, - monitoring_interval=1, - ) - assert not mock_syncer.canceled - mock_syncer._pre_sync.assert_called_once_with(False) - mock_syncer.status.initialize.assert_called_once_with(block_size=1, source_size=10, destination_size=10) - mock_syncer._suspended.set.assert_called_once() - worker_instance = mock_worker.return_value - mock_worker.assert_called_once_with( - worker_id=1, - syncer=mock_syncer, - startpos=0, - endpos=10, - dryrun=False, - sync_interval=1, - monitoring_interval=1, - logger=mock_syncer.logger, - ) - mock_worker.return_value.start.assert_called_once() - mock_syncer.wait.assert_called_once() - assert mock_syncer.workers == 
[worker_instance] - assert mock_syncer.started - - # Expect: Create worker instances by number of workers - mock_syncer.start_sync( - workers=3, - block_size=1, - wait=True, - dryrun=False, - create=False, - sync_interval=1, - monitoring_interval=1, - ) - assert mock_syncer.workers == [worker_instance, worker_instance, worker_instance] - - -def test_finished(mock_syncer, mock_worker): - # Expect: Return False before start - assert not mock_syncer.finished - mock_syncer.start_sync() - - # Expect: Return False if some worker is alive - mock_worker.return_value.is_alive.return_value = True - assert not mock_syncer.finished - - # Expect: Return True if all worker finished - mock_worker.return_value.is_alive.return_value = False - assert mock_syncer.finished - - -def test_wait(mock_syncer, mock_worker): - # Expect: Wait all workers - mock_worker.return_value.is_alive.return_value = True - mock_syncer.start_sync(workers=2).wait() - assert mock_worker.return_value.is_alive.call_count == 2 - assert mock_worker.return_value.join.call_count == 2 - - -def test_cancel(mock_syncer): - assert mock_syncer.cancel()._canceled - - -def test_suspend(mock_syncer): - # Expect: Suspend when not blocking - mock_syncer._suspended.is_set.return_value = True - mock_syncer.suspend() - mock_syncer._suspended.clear.assert_called_once() - mock_syncer._suspended.reset_mock() - - mock_syncer._suspended.is_set.return_value = False - mock_syncer.suspend() - assert mock_syncer.suspended - mock_syncer._suspended.clear.assert_not_called() - - -def test_resume(mock_syncer): - # Expect: Resume when blocking - mock_syncer._suspended.is_set.return_value = False - mock_syncer.resume() - mock_syncer._suspended.set.assert_called_once() - mock_syncer._suspended.reset_mock() - - mock_syncer._suspended.is_set.return_value = True - mock_syncer.resume() - assert not mock_syncer.suspended - mock_syncer._suspended.set.assert_not_called() diff --git a/tests/test_worker.py b/tests/test_worker.py deleted file mode 100644 index 982c367..0000000 --- a/tests/test_worker.py +++ /dev/null @@ -1,105 +0,0 @@ -import logging -import os -from unittest.mock import Mock, call - -import pytest - -from blocksync.worker import Worker - - -@pytest.fixture(autouse=True) -def mock_time(mocker): - return mocker.patch("blocksync.worker.time") - - -@pytest.fixture -def stub_worker(mocker): - worker = Worker( - worker_id=1, - syncer=Mock(canceled=False, suspended=False), - startpos=0, - endpos=5, - dryrun=False, - sync_interval=1, - monitoring_interval=1, - logger=Mock(), - ) - mocker.patch("blocksync.worker.timer", side_effect=[0, 1, 1.5, 2, 3]) - worker.syncer.status.block_size = 2 - worker.syncer.src.io.tell.side_effect = [1, 2, 4, 5] # type: ignore[union-attr] - worker.syncer.src.get_blocks.return_value = [b"1", b"2", b"33", b"5"] # type: ignore[attr-defined] - worker.syncer.dest.get_blocks.return_value = [b"1", b"2", b"44", b"6"] # type: ignore[attr-defined] - return worker - - -def test_run(stub_worker): - stub_worker._sync = Mock() - stub_worker.run() - stub_worker.syncer.hooks.run_root_before.assert_called_once() - stub_worker.syncer.hooks.run_root_after.assert_called_once() - stub_worker._sync.assert_called_once() - stub_worker.syncer.src.do_close.assert_called_once() - stub_worker.syncer.dest.do_close.assert_called_once() - - -def test_sync(stub_worker, mock_time): - stub_worker._sync() - stub_worker.syncer.src.do_open.return_value.io.seek.assert_called_once_with(stub_worker.startpos) - 
stub_worker.syncer.dest.do_open.return_value.io.seek.assert_called_once_with(stub_worker.startpos) - stub_worker.syncer.hooks.run_before.assert_called_once() - - stub_worker.syncer.status._add_block.assert_has_calls([call("same"), call("same"), call("diff")]) - - stub_worker.syncer.dest.io.seek.assert_has_calls([call(-2, os.SEEK_CUR), call(-1, os.SEEK_CUR)]) - stub_worker.syncer.dest.io.write.assert_has_calls([call(b"33"), call(b"5")]) - assert stub_worker.syncer.dest.io.flush.call_count == 2 - - stub_worker.syncer.hooks.run_monitor.assert_has_calls( - [ - call(stub_worker.syncer.status), - call(stub_worker.syncer.status), - call(stub_worker.syncer.status), - ] - ) - - mock_time.sleep.assert_has_calls([call(1), call(1), call(1)]) - - stub_worker.logger.log.assert_has_calls( - [ - call( - logging.INFO, - f"[Worker 1]: Start sync(startpos: 0, endpos: 5) {stub_worker.syncer.src} to {stub_worker.syncer.dest}", - ), - call(logging.INFO, "[Worker 1]: !!! Done !!!"), - ] - ) - stub_worker.syncer.hooks.run_after.assert_called_once_with(stub_worker.syncer.status) - - -def test_sync_with_suspending(stub_worker): - stub_worker.syncer.suspended = True - stub_worker._sync() - stub_worker.logger.log.assert_has_calls([call(logging.INFO, "[Worker 1]: Suspending...")]) - stub_worker.syncer._suspended.wait.assert_called() - - -def test_sync_with_canceling(stub_worker): - stub_worker.syncer.canceled = True - stub_worker._sync() - stub_worker.logger.log.assert_has_calls([call(logging.INFO, "[Worker 1]: Synchronization task has been canceled")]) - - -def test_run_on_error(stub_worker): - expected_error = OSError("Error raised when writing data") - stub_worker.syncer.dest.io.write.side_effect = expected_error - stub_worker._sync() - stub_worker.syncer.hooks.run_on_error.assert_called_once_with(expected_error, stub_worker.syncer.status) - stub_worker.logger.log.assert_has_calls( - [call(logging.ERROR, "[Worker 1]: Error raised when writing data", exc_info=True)] - ) - - -def test_log(stub_worker): - # Expect: Call logger.log with message, level, additional arguments - stub_worker._log("test", logging.DEBUG, 1, exc_info=True) - stub_worker.logger.log.assert_called_once_with(logging.DEBUG, "[Worker 1]: test", 1, exc_info=True) diff --git a/tests/test_write_server.py b/tests/test_write_server.py new file mode 100644 index 0000000..54bac84 --- /dev/null +++ b/tests/test_write_server.py @@ -0,0 +1,21 @@ +import subprocess + +from blocksync._consts import BASE_DIR + + +def test_write_server(pytester): + p = pytester.popen( + ["python", (BASE_DIR / "_write_server.py")], + bufsize=0, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + stdin = p.stdin + dest_file_path = str(pytester.path / "dest.img") + expected_dest_file_content = b"a" * 20 + stdin.write(f"{dest_file_path}\n20\n20\n0\n1\n".encode()) + stdin.write(b"2") + stdin.write(expected_dest_file_content) + p.wait() + dest_file = open(dest_file_path, "rb") + assert dest_file.read() == expected_dest_file_content