#!/usr/bin/env python
"""
Shifted Counter Generators for HDC
HDC-compatible wrapper for counter-based random number generation with
invertible mappings (Feistel, SPN, ARX, etc.).
"""
import random
from typing import Any, Callable, Dict, List, Optional
from pyhdc.generation.base import HDCGenerator
class ShiftedCounterGenerator(HDCGenerator):
    """
    Base class for shifted counter generators with invertible mappings.

    Every draw rotates the current counter left by ``shift_amount`` bits
    inside a ``bit_width``-bit word, feeds the rotated value through
    ``_apply_mapping`` (provided by subclasses) and then increments the
    counter modulo ``2 ** bit_width``.
    """

    def __init__(
        self, bit_width: int = 32, shift_amount: int = 13, seed: Optional[int] = None
    ) -> None:
        """
        Initialize shifted counter generator.

        Args:
            bit_width: Width of the counter in bits (positive integer).
            shift_amount: Rotation applied to the counter; stored reduced
                modulo ``bit_width``.
            seed: Optional seed for reproducibility; passed to the parent
                initializer.

        Raises:
            ValueError: If ``bit_width`` is not a positive integer or
                ``shift_amount`` is not an integer.
        """
        self._bit_width = bit_width
        self._shift_amount = self._normalize_shift(shift_amount, bit_width)
        self._validate_parameters()
        # NOTE(review): the parent initializer presumably stores the seed as
        # ``self._seed`` and calls ``_configure_internal`` -- confirm against
        # HDCGenerator.
        super().__init__(seed)

    @staticmethod
    def _normalize_shift(shift_amount: Any, bit_width: Any) -> Any:
        """
        Reduce ``shift_amount`` modulo ``bit_width`` without raising.

        Unusable inputs are passed on (or replaced by 0 when the width itself
        is unusable) so that ``_validate_parameters`` reports them with a
        ``ValueError`` instead of an arithmetic ``TypeError`` surfacing here.
        """
        if not isinstance(bit_width, int) or bit_width <= 0:
            return 0
        if not isinstance(shift_amount, int):
            return shift_amount
        return shift_amount % bit_width

    def _validate_parameters(self) -> None:
        """
        Validate counter parameters.

        Raises:
            ValueError: If the stored bit width is not a positive integer or
                the stored shift amount is not a non-negative integer.
        """
        if not isinstance(self._bit_width, int) or self._bit_width <= 0:
            raise ValueError("Bit width must be a positive integer")
        if not isinstance(self._shift_amount, int) or self._shift_amount < 0:
            raise ValueError("Shift amount must be a non-negative integer")

    def _configure_internal(self) -> None:
        """
        Configure the counter state.

        Derives the word mask from the bit width and picks the start counter:
        the seed when one is set, otherwise a value drawn from the global
        ``random`` module.

        Raises:
            ValueError: If the seed is not an integer in
                ``[0, 2 ** bit_width - 1]``.
        """
        max_value = (1 << self._bit_width) - 1
        seed = self._seed
        if seed is None:
            start = random.randint(0, max_value)
        elif isinstance(seed, int) and 0 <= seed <= max_value:
            start = seed
        else:
            raise ValueError(f"Seed must be between 0 and {max_value}")
        # Commit only after the seed check so a rejected seed changes nothing.
        self._max_value = max_value
        self._counter = start
        self._initial_counter = start

    def _shift_counter(self) -> int:
        """Return the current counter rotated left by the shift amount."""
        high = self._counter << self._shift_amount
        # With a shift of 0 this term is ``counter >> bit_width``, i.e. 0.
        low = self._counter >> (self._bit_width - self._shift_amount)
        return (high | low) & self._max_value

    def _apply_mapping(self, value: int) -> int:
        """Apply invertible mapping (implemented by subclasses)."""
        raise NotImplementedError("Subclasses must implement _apply_mapping")

    def _next_value(self) -> int:
        """Generate the next ``bit_width``-bit value and advance the counter."""
        output = self._apply_mapping(self._shift_counter())
        self._counter = (self._counter + 1) & self._max_value
        return output

    def _next_bit(self) -> int:
        """Generate next bit (the least significant bit of the next value)."""
        return self._next_value() & 1

    def _next_word(self, word_size: int = 32) -> int:
        """
        Generate next word.

        Args:
            word_size: Number of bits requested.

        Returns:
            The low ``word_size`` bits of one value when it fits in a single
            draw; otherwise several draws concatenated, first draw in the
            least significant position, truncated to ``word_size`` bits.
        """
        word_mask = (1 << word_size) - 1
        if word_size <= self._bit_width:
            return self._next_value() & word_mask
        result = 0
        bits_collected = 0
        while bits_collected < word_size:
            result |= self._next_value() << bits_collected
            bits_collected += self._bit_width
        return result & word_mask

    def set_parameters(
        self,
        bit_width: Optional[int] = None,
        shift_amount: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Set counter parameters and reconfigure the generator.

        Arguments left as ``None`` keep their current value. The counter is
        re-initialized by ``_configure_internal``; without a stored seed that
        means a fresh random start value. If validation or reconfiguration
        fails, the previous state is restored before the error propagates.

        Raises:
            ValueError: If the resulting parameters or seed are invalid.
        """
        # NOTE(review): the snapshot covers instance ``__dict__`` attributes
        # only; confirm HDCGenerator does not keep state in ``__slots__``.
        snapshot = dict(vars(self))
        try:
            if bit_width is not None:
                self._bit_width = bit_width
                self._shift_amount = self._normalize_shift(
                    self._shift_amount, bit_width
                )
            if shift_amount is not None:
                self._shift_amount = self._normalize_shift(
                    shift_amount, self._bit_width
                )
            if seed is not None:
                self._seed = seed
            self._validate_parameters()
            self._configure_internal()
        except Exception:
            # Roll back so a rejected update never leaves a half-changed object.
            state = vars(self)
            state.clear()
            state.update(snapshot)
            raise

    def set_bit_width(self, bit_width: int) -> None:
        """Set the bit width parameter."""
        self.set_parameters(bit_width=bit_width)

    def set_shift_amount(self, shift_amount: int) -> None:
        """Set the shift amount parameter."""
        self.set_parameters(shift_amount=shift_amount)

    def get_parameters(self) -> Dict[str, Any]:
        """Get current parameters as a dict (bit_width, shift_amount, seed)."""
        return {
            "bit_width": self._bit_width,
            "shift_amount": self._shift_amount,
            "seed": self._seed,
        }

    def get_bit_width(self) -> int:
        """Get the bit width parameter."""
        return self._bit_width

    def get_shift_amount(self) -> int:
        """Get the shift amount parameter (already reduced modulo bit width)."""
        return self._shift_amount

    def get_period(self) -> int:
        """Get the period length: ``2 ** bit_width`` counter states."""
        return 1 << self._bit_width

    def reset(self) -> None:
        """Reset the counter to the value chosen at the last configuration."""
        self._counter = self._initial_counter

    def get_state(self) -> int:
        """Get current counter state."""
        return self._counter
class FeistelCounterGenerator(ShiftedCounterGenerator):
    """
    Shifted counter with Feistel network mapping.

    Uses a Feistel network to create an invertible permutation of the
    ``bit_width``-bit word: the word is split into two halves that are mixed
    for ``rounds`` rounds with one round key each.
    """

    def __init__(
        self,
        bit_width: int = 32,
        shift_amount: int = 13,
        rounds: int = 4,
        round_keys: Optional[List[int]] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Initialize Feistel counter.

        Args:
            bit_width: Width of the counter in bits; must be even.
            shift_amount: Rotation applied to the counter before mapping.
            rounds: Number of Feistel rounds (positive integer).
            round_keys: Optional list with exactly ``rounds`` keys; each key
                is masked to half the bit width. Generated when omitted.
            seed: Optional seed for reproducibility.

        Raises:
            ValueError: If ``bit_width`` is odd, ``rounds`` is not a positive
                integer, or the number of round keys does not match.
        """
        # Non-integer widths fall through to the base-class validation.
        if isinstance(bit_width, int) and bit_width % 2 != 0:
            raise ValueError("Bit width must be even for Feistel networks")
        if not isinstance(rounds, int) or rounds <= 0:
            raise ValueError("Rounds must be a positive integer")
        self._rounds = rounds
        self._round_keys = round_keys
        super().__init__(bit_width, shift_amount, seed)

    def _validate_parameters(self) -> None:
        """
        Validate counter parameters, including the even-width requirement.

        Overridden so that a width changed later through ``set_parameters``
        or ``set_bit_width`` is checked as well, not only the constructor
        argument.
        """
        super()._validate_parameters()
        if self._bit_width % 2 != 0:
            raise ValueError("Bit width must be even for Feistel networks")

    def _configure_internal(self) -> None:
        """
        Configure with Feistel-specific parameters.

        Raises:
            ValueError: If round keys are present but their count differs
                from ``rounds``.
        """
        super()._configure_internal()
        self._half_width = self._bit_width // 2
        self._half_mask = (1 << self._half_width) - 1
        if self._round_keys is None:
            # Use a seeded local RNG so round keys are deterministic when a
            # seed is set.
            rng = random.Random(self._seed)
            self._round_keys = [
                rng.randint(0, self._half_mask) for _ in range(self._rounds)
            ]
            return
        if len(self._round_keys) != self._rounds:
            raise ValueError(f"Must provide {self._rounds} round keys")
        self._round_keys = [k & self._half_mask for k in self._round_keys]

    def _rotl_half(self, value: int, amount: int) -> int:
        """Rotate a half-block left by ``amount`` bits within the half width."""
        # Reducing the amount keeps the right shift non-negative when the
        # half width is smaller than the requested rotation (bit widths 2, 4).
        amount %= self._half_width
        return (
            (value << amount) | (value >> (self._half_width - amount))
        ) & self._half_mask

    def _feistel_function(self, half_block: int, round_key: int) -> int:
        """
        Feistel function (F-function).

        XORs the half-block with the round key, then diffuses the bits with a
        3-bit rotation, an xor-shift and a 1-bit rotation. All of these steps
        are linear over GF(2).
        """
        mixed = half_block ^ round_key
        mixed = self._rotl_half(mixed, 3)
        mixed ^= mixed >> 1
        return self._rotl_half(mixed, 1)

    def _apply_mapping(self, value: int) -> int:
        """Apply Feistel network mapping to a ``bit_width``-bit value."""
        left = (value >> self._half_width) & self._half_mask
        right = value & self._half_mask
        for round_key in self._round_keys[: self._rounds]:
            # Standard Feistel round: (L, R) -> (R, L xor F(R, key)).
            mixed = left ^ self._feistel_function(right, round_key)
            left, right = right, mixed & self._half_mask
        return ((left << self._half_width) | right) & self._max_value

    def set_round_keys(self, round_keys: List[int]) -> None:
        """
        Set the round keys.

        Raises:
            ValueError: If the number of keys differs from ``rounds``.
        """
        if len(round_keys) != self._rounds:
            raise ValueError(f"Must provide {self._rounds} round keys")
        self._round_keys = [k & self._half_mask for k in round_keys]

    def get_round_keys(self) -> List[int]:
        """Get a copy of the round keys."""
        return self._round_keys.copy()
class ARXCounterGenerator(ShiftedCounterGenerator):
    """
    Shifted counter with Addition-Rotation-XOR mapping.

    Each round adds a constant modulo ``2 ** bit_width``, rotates the word
    left and XORs in a shifted copy of the same constant.
    """

    # Default constants (similar to ChaCha/Salsa).
    _DEFAULT_CONSTANTS = (0x61707865, 0x3320646E, 0x79622D32, 0x6B206574)

    def __init__(
        self,
        bit_width: int = 32,
        shift_amount: int = 11,
        constants: Optional[List[int]] = None,
        rotations: Optional[List[int]] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Initialize ARX counter.

        Args:
            bit_width: Width of the counter in bits.
            shift_amount: Rotation applied to the counter before mapping.
            constants: Optional per-round constants; masked to the bit width.
            rotations: Optional per-round left-rotation amounts.
            seed: Optional seed for reproducibility.
        """
        self._constants = constants
        self._rotations = rotations
        super().__init__(bit_width, shift_amount, seed)

    def _configure_internal(self) -> None:
        """Configure with ARX-specific parameters (constants and rotations)."""
        super()._configure_internal()
        source = (
            self._DEFAULT_CONSTANTS if self._constants is None else self._constants
        )
        self._constants = [c & self._max_value for c in source]
        if self._rotations is None:
            self._rotations = (
                [7, 12, 17, 22] if self._bit_width >= 32 else [3, 5, 7, 11]
            )
        else:
            # Own copy: later changes to the caller's sequence must not alter
            # the mapping, and get_rotations() relies on list.copy().
            self._rotations = list(self._rotations)

    def _rotleft(self, value: int, amount: int) -> int:
        """Rotate value left by ``amount`` bits (reduced modulo bit width)."""
        amount %= self._bit_width
        return (
            (value << amount) | (value >> (self._bit_width - amount))
        ) & self._max_value

    def _apply_mapping(self, value: int) -> int:
        """
        Apply ARX mapping.

        Constants and rotations are paired with ``zip``, so the number of
        rounds equals the length of the shorter list.
        """
        result = value
        rounds = zip(self._constants, self._rotations)
        for shift, (constant, rotation) in enumerate(rounds, start=1):
            # Addition
            result = (result + constant) & self._max_value
            # Rotation
            result = self._rotleft(result, rotation)
            # XOR with the constant shifted left by the 1-based round number
            result ^= (constant << shift) & self._max_value
        return result

    def set_constants(self, constants: List[int]) -> None:
        """Set the ARX constants (each masked to the bit width)."""
        self._constants = [c & self._max_value for c in constants]

    def set_rotations(self, rotations: List[int]) -> None:
        """Set the rotation amounts (a copy of the given sequence is stored)."""
        self._rotations = list(rotations)

    def get_constants(self) -> List[int]:
        """Get a copy of the ARX constants."""
        return self._constants.copy()

    def get_rotations(self) -> List[int]:
        """Get a copy of the rotation amounts."""
        return self._rotations.copy()
class SPNCounterGenerator(ShiftedCounterGenerator):
    """
    Shifted counter with Substitution-Permutation Network mapping.

    The word is split into ``sbox_size``-bit groups that are each replaced
    through the S-box; the bits of the result are then moved to new positions
    by the P-box.
    """

    def __init__(
        self,
        bit_width: int = 32,
        shift_amount: int = 17,
        sbox_size: int = 4,
        sbox: Optional[List[int]] = None,
        pbox: Optional[List[int]] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Initialize SPN counter.

        Args:
            bit_width: Width of the counter in bits; a multiple of
                ``sbox_size``.
            shift_amount: Rotation applied to the counter before mapping.
            sbox_size: Width of one S-box input in bits (positive integer).
            sbox: Optional permutation of ``range(2 ** sbox_size)``.
            pbox: Optional permutation of ``range(bit_width)``.
            seed: Optional seed for reproducibility.

        Raises:
            ValueError: If ``sbox_size`` is not a positive integer, does not
                divide ``bit_width``, or a supplied box is not a valid
                permutation.
        """
        # Checked first: a zero size would otherwise raise ZeroDivisionError
        # in the divisibility test below.
        if not isinstance(sbox_size, int) or sbox_size <= 0:
            raise ValueError("S-box size must be a positive integer")
        if bit_width % sbox_size != 0:
            raise ValueError(f"Bit width must be divisible by S-box size ({sbox_size})")
        self._sbox_size = sbox_size
        self._sbox = sbox
        self._pbox = pbox
        super().__init__(bit_width, shift_amount, seed)

    def _checked_sbox(self, sbox: List[int]) -> List[int]:
        """Return a list copy of ``sbox`` after checking it is a permutation."""
        entries = 1 << self._sbox_size
        if len(sbox) != entries:
            raise ValueError(f"S-box must have {entries} entries")
        if sorted(sbox) != list(range(entries)):
            raise ValueError("S-box must be a valid permutation")
        return list(sbox)

    def _checked_pbox(self, pbox: List[int]) -> List[int]:
        """Return a list copy of ``pbox`` after checking it permutes the bits."""
        positions = list(range(self._bit_width))
        if len(pbox) != self._bit_width or sorted(pbox) != positions:
            raise ValueError("P-box must be a valid bit permutation")
        return list(pbox)

    def _configure_internal(self) -> None:
        """
        Configure with SPN-specific parameters.

        Missing boxes are generated by shuffling. With a seed the shuffle
        uses a local RNG seeded from it, so equal seeds give equal boxes;
        without a seed the global ``random`` module is used.

        Raises:
            ValueError: If a stored S-box or P-box is not a valid permutation
                for the current sizes.
        """
        super()._configure_internal()
        self._sbox_count = self._bit_width // self._sbox_size
        self._sbox_mask = (1 << self._sbox_size) - 1
        rng = random if self._seed is None else random.Random(self._seed)
        if self._sbox is None:
            # Generate random permutation for S-box
            table = list(range(1 << self._sbox_size))
            rng.shuffle(table)
            self._sbox = table
        else:
            self._sbox = self._checked_sbox(self._sbox)
        if self._pbox is None:
            # Generate random bit permutation
            wiring = list(range(self._bit_width))
            rng.shuffle(wiring)
            self._pbox = wiring
        else:
            self._pbox = self._checked_pbox(self._pbox)

    def _apply_sbox(self, value: int) -> int:
        """Apply S-box substitution to every ``sbox_size``-bit group."""
        result = 0
        for index in range(self._sbox_count):
            offset = index * self._sbox_size
            group = (value >> offset) & self._sbox_mask
            result |= self._sbox[group] << offset
        return result & self._max_value

    def _apply_pbox(self, value: int) -> int:
        """Apply P-box permutation: input bit ``i`` moves to ``pbox[i]``."""
        result = 0
        for source, target in enumerate(self._pbox):
            if (value >> source) & 1:
                result |= 1 << target
        return result

    def _apply_mapping(self, value: int) -> int:
        """Apply SPN mapping: substitution followed by bit permutation."""
        return self._apply_pbox(self._apply_sbox(value))

    def set_sbox(self, sbox: List[int]) -> None:
        """
        Set the S-box.

        Raises:
            ValueError: If ``sbox`` is not a permutation of
                ``range(2 ** sbox_size)``.
        """
        self._sbox = self._checked_sbox(sbox)

    def set_pbox(self, pbox: List[int]) -> None:
        """
        Set the P-box.

        Raises:
            ValueError: If ``pbox`` is not a permutation of
                ``range(bit_width)``.
        """
        self._pbox = self._checked_pbox(pbox)

    def get_sbox(self) -> List[int]:
        """Get a copy of the S-box."""
        return self._sbox.copy()

    def get_pbox(self) -> List[int]:
        """Get a copy of the P-box."""
        return self._pbox.copy()
class CustomMappingCounterGenerator(ShiftedCounterGenerator):
    """
    Shifted counter with custom user-defined mapping.

    Allows users to provide their own invertible mapping function. The
    function receives the rotated counter value; its result is masked to
    ``bit_width`` bits.
    """

    def __init__(
        self,
        bit_width: int = 32,
        shift_amount: int = 15,
        mapping_func: Optional[Callable[[int], int]] = None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Initialize custom mapping counter.

        Args:
            bit_width: Width of the counter in bits.
            shift_amount: Rotation applied to the counter before mapping.
            mapping_func: Callable taking and returning an int; a built-in
                default hash is used when omitted.
            seed: Optional seed for reproducibility.

        Raises:
            TypeError: If ``mapping_func`` is given but not callable.
        """
        self._mapping_func = mapping_func
        super().__init__(bit_width, shift_amount, seed)

    @staticmethod
    def _default_mapping(x: int) -> int:
        """
        Default simple hash function.

        NOTE(review): the constants and shifts match the 32-bit finalizer of
        MurmurHash3, but the products here are not truncated between steps,
        so this is not guaranteed to be a bijection on the word. It is left
        unchanged to keep existing output streams stable.
        """
        x ^= x >> 16
        x *= 0x85EBCA6B
        x ^= x >> 13
        x *= 0xC2B2AE35
        x ^= x >> 16
        return x

    def _configure_internal(self) -> None:
        """
        Configure with default mapping if none provided.

        Raises:
            TypeError: If the stored mapping is not callable.
        """
        super()._configure_internal()
        if self._mapping_func is None:
            self._mapping_func = self._default_mapping
        elif not callable(self._mapping_func):
            # Fail here rather than on the first draw.
            raise TypeError("mapping_func must be callable")

    def _apply_mapping(self, value: int) -> int:
        """Apply custom mapping and mask the result to the bit width."""
        return self._mapping_func(value) & self._max_value

    def set_mapping_function(self, mapping_func: Callable[[int], int]) -> None:
        """
        Set a new mapping function.

        Passing ``None`` restores the built-in default mapping.

        Raises:
            TypeError: If ``mapping_func`` is neither ``None`` nor callable.
        """
        if mapping_func is None:
            self._mapping_func = self._default_mapping
            return
        if not callable(mapping_func):
            raise TypeError("mapping_func must be callable")
        self._mapping_func = mapping_func
# Predefined counter configurations
class CommonCounterGenerators:
    """Namespace of ready-made constructors for common counter setups."""

    @staticmethod
    def feistel_32bit(seed: Optional[int] = None) -> FeistelCounterGenerator:
        """Build a 32-bit Feistel counter (rotation 13, four rounds)."""
        settings = {"bit_width": 32, "shift_amount": 13, "rounds": 4}
        return FeistelCounterGenerator(seed=seed, **settings)

    @staticmethod
    def feistel_64bit(seed: Optional[int] = None) -> FeistelCounterGenerator:
        """Build a 64-bit Feistel counter (rotation 21, six rounds)."""
        settings = {"bit_width": 64, "shift_amount": 21, "rounds": 6}
        return FeistelCounterGenerator(seed=seed, **settings)

    @staticmethod
    def arx_32bit(seed: Optional[int] = None) -> ARXCounterGenerator:
        """Build a 32-bit ARX counter (rotation 11)."""
        settings = {"bit_width": 32, "shift_amount": 11}
        return ARXCounterGenerator(seed=seed, **settings)

    @staticmethod
    def arx_64bit(seed: Optional[int] = None) -> ARXCounterGenerator:
        """Build a 64-bit ARX counter (rotation 17)."""
        settings = {"bit_width": 64, "shift_amount": 17}
        return ARXCounterGenerator(seed=seed, **settings)

    @staticmethod
    def spn_32bit(seed: Optional[int] = None) -> SPNCounterGenerator:
        """Build a 32-bit SPN counter with 4-bit S-boxes (rotation 17)."""
        settings = {"bit_width": 32, "shift_amount": 17, "sbox_size": 4}
        return SPNCounterGenerator(seed=seed, **settings)

    @staticmethod
    def simple_hash(
        bit_width: int = 32, seed: Optional[int] = None
    ) -> CustomMappingCounterGenerator:
        """Build a counter of the given width that uses the default mapping.

        No mapping function is supplied, so the generator falls back to its
        built-in hash; the counter rotation is 15.
        """
        settings = {"bit_width": bit_width, "shift_amount": 15}
        return CustomMappingCounterGenerator(seed=seed, **settings)
# Legacy alias for backward compatibility: ``ShiftedCounter`` is bound to the
# same class object as ``FeistelCounterGenerator``, so both names construct
# identical generators.
ShiftedCounter = FeistelCounterGenerator