Files
ldpc/srctmp/decoder.rs
2026-06-01 09:13:24 +02:00

447 lines
12 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::code::LdpcCode;
use crate::graph::TannerGraph;
use crate::matrix::SparseMatrixGF2;
use crate::{BitVec, Gf2, Llr};
// Decoding result

/// Outcome of a decoding attempt.
#[derive(Debug, Clone)]
pub enum DecoderResult {
    /// The carried word has an all-zero syndrome.
    Converged(BitVec),
    /// Decoding ended with a nonzero syndrome; carries the last hard decision.
    MaxIterationsReached(BitVec),
    /// No estimate is available; [`DecoderResult::codeword`] returns `None`.
    Failure,
}

impl DecoderResult {
    /// Returns the decoded word for both converged and non-converged runs.
    ///
    /// Only [`DecoderResult::Failure`] yields `None`.
    pub fn codeword(&self) -> Option<&BitVec> {
        match self {
            Self::Converged(word) => Some(word),
            Self::MaxIterationsReached(word) => Some(word),
            Self::Failure => None,
        }
    }

    /// `true` only for [`DecoderResult::Converged`].
    pub fn is_success(&self) -> bool {
        match self {
            Self::Converged(_) => true,
            Self::MaxIterationsReached(_) | Self::Failure => false,
        }
    }
}
// Configuration

/// Settings shared by every decoder in this module.
#[derive(Debug, Clone)]
pub struct DecoderConfig {
    /// Upper bound on the number of decoding iterations.
    pub max_iterations: usize,
    /// When `true`, stop as soon as the current hard decision has an all-zero syndrome.
    pub early_stopping: bool,
}

impl Default for DecoderConfig {
    /// 50 iterations with early stopping enabled.
    fn default() -> Self {
        let max_iterations = 50;
        let early_stopping = true;
        Self {
            max_iterations,
            early_stopping,
        }
    }
}
// Decoder trait

/// Common interface implemented by the LDPC decoders in this module.
pub trait Decoder: Send + Sync {
    /// Decodes from per-bit channel LLRs (a non-negative LLR favours bit 0,
    /// matching `hard_decision`).
    fn decode(&self, channel_llr: &[Llr]) -> DecoderResult;

    /// Decodes hard-decided bits by mapping them to unit-magnitude LLRs.
    ///
    /// A `0` becomes `+1.0`; any other value becomes `-1.0`.
    fn decode_hard(&self, received: &[Gf2]) -> DecoderResult {
        let mut soft: Vec<Llr> = Vec::with_capacity(received.len());
        for &bit in received {
            soft.push(if bit == 0 { 1.0 } else { -1.0 });
        }
        self.decode(&soft)
    }
}
// GF(2) and LLR primitives

/// Maps an LLR to a hard bit decision.
///
/// LLRs that compare `>= 0.0` (including `-0.0`) decide `0`; everything else
/// decides `1`. That includes `NaN`, which fails the comparison.
#[inline]
pub fn hard_decision(llr: Llr) -> Gf2 {
    let favours_zero = llr >= 0.0;
    if favours_zero {
        0
    } else {
        1
    }
}
pub fn compute_syndrome(h: &SparseMatrixGF2, c: &[Gf2]) -> Vec<Gf2> {
h.multiply_vec(c)
}
// phi(x) = -ln(tanh(|x|/2)) involution de Sum-Product
// phi(phi(x)) = x
#[inline]
fn phi(x: Llr) -> Llr {
let ax = x.abs().max(1e-10);
-((ax / 2.0).tanh().ln())
}
// Node updates

/// Sum-Product check-node update.
///
/// For each edge `v` of the check node:
/// `R_{c→v} = φ(Σ_{v'≠v} φ(|L_{v'→c}|)) · Π_{v'≠v} sign(L_{v'→c})`.
///
/// The exclusive sum and sign product come from the totals over all edges, with
/// this edge's own contribution removed. Its φ term is subtracted. Its sign is
/// multiplied back in, since a ±1 sign is its own inverse.
/// `incoming[k]` produces `out[k]`.
fn check_node_update_sp(incoming: &[Llr], out: &mut [Llr]) {
    let phi_sum: Llr = incoming.iter().map(|&l| phi(l.abs())).sum();
    let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
    for (&l, r) in incoming.iter().zip(out.iter_mut()) {
        let phi_excl = phi_sum - phi(l.abs());
        let sign_excl = sign_prod * l.signum();
        *r = sign_excl * phi(phi_excl);
    }
}
/// Normalised Min-Sum check-node update.
///
/// Approximates the Sum-Product rule by replacing `φ(Σ φ(|L|))` with the
/// minimum `|L|` over the other edges, then scales the result by `alpha`.
/// An `alpha` in [0.75, 0.875] compensates for the bias of raw Min-Sum.
///
/// Only the smallest and second-smallest magnitudes are needed:
/// - The edge holding the minimum receives the second minimum.
/// - Every other edge receives the minimum.
///
/// On ties, the first index holds the minimum. With a single incoming message
/// there is no second minimum, so that edge receives `alpha · ±∞`.
fn check_node_update_ms(incoming: &[Llr], out: &mut [Llr], alpha: Llr) {
    let seed: (Llr, Llr, Llr, usize) = (1.0, Llr::INFINITY, Llr::INFINITY, 0);
    let (sign_prod, min1, min2, min1_idx) = incoming.iter().enumerate().fold(
        seed,
        |(sign, best, runner_up, best_idx), (idx, &l)| {
            let mag = l.abs();
            let sign = sign * l.signum();
            if mag < best {
                (sign, mag, best, idx)
            } else if mag < runner_up {
                (sign, best, mag, best_idx)
            } else {
                (sign, best, runner_up, best_idx)
            }
        },
    );
    for (idx, (r, &l)) in out.iter_mut().zip(incoming).enumerate() {
        let min_excl = if idx == min1_idx { min2 } else { min1 };
        let sign_excl = sign_prod * l.signum();
        *r = alpha * sign_excl * min_excl;
    }
}
// Mise à jour du nœud de variable
// L_{v→c} = L_channel(v) + ∑_{c'≠c} R_{c'→v}
fn variable_node_update(ch_llr: Llr, incoming_c2v: &[Llr], out: &mut [Llr]) {
let total: Llr = ch_llr + incoming_c2v.iter().sum::<Llr>();
for ((&r, o)) in incoming_c2v.iter().zip(out.iter_mut()) {
*o = total - r;
}
}
/// A-posteriori LLR of a variable node.
///
/// It is the channel LLR plus every check-to-variable message addressed to the node.
#[inline]
fn posterior_llr(ch_llr: Llr, c2v_msgs: &[Llr]) -> Llr {
    let extrinsic: Llr = c2v_msgs.iter().sum();
    ch_llr + extrinsic
}
// Internal message storage

/// Edge messages stored per check node.
///
/// They are indexed by `(check index, position in that check's neighbour list)`,
/// which gives O(1) access.
struct Messages {
    /// `v2c[c][j]`: message from variable `chk_neighbors(c)[j]` to check `c`.
    v2c: Vec<Vec<Llr>>,
    /// `c2v[c][j]`: message from check `c` to variable `chk_neighbors(c)[j]`.
    c2v: Vec<Vec<Llr>>,
}

impl Messages {
    /// Allocates zero-initialised buffers for both directions.
    ///
    /// Each direction gets `chk_degree(c)` slots for every check `c` of `graph`.
    fn new(graph: &TannerGraph) -> Self {
        let v2c: Vec<Vec<Llr>> = (0..graph.n_chk)
            .map(|c| vec![0.0; graph.chk_degree(c)])
            .collect();
        let c2v = v2c.clone();
        Self { v2c, c2v }
    }
}
/// Edge-position lookup tables, computed once when a decoder is built.
///
/// Both tables translate an edge seen from one side of the Tanner graph into
/// its index in the other side's neighbour list.
struct EdgeIndex {
    /// `var_pos_in_chk[c][j]`: position of check `c` inside `var_neighbors(v)`,
    /// where `v = chk_neighbors(c)[j]`.
    var_pos_in_chk: Vec<Vec<usize>>,
    /// `chk_pos_in_var[v][i]`: position of variable `v` inside `chk_neighbors(c)`,
    /// where `c = var_neighbors(v)[i]`.
    ///
    /// Used to address the per-check `Messages` slots (`[c][j]`) while iterating
    /// from a variable node.
    chk_pos_in_var: Vec<Vec<usize>>,
}

impl EdgeIndex {
    /// Builds both lookup tables from `graph`.
    ///
    /// # Panics
    ///
    /// Panics if the adjacency lists are not mutually consistent. That happens when
    /// `v` appears in `chk_neighbors(c)` but `c` is absent from `var_neighbors(v)`,
    /// or vice versa.
    fn build(graph: &TannerGraph) -> Self {
        let var_pos_in_chk = (0..graph.n_chk)
            .map(|c| {
                graph
                    .chk_neighbors(c)
                    .iter()
                    .map(|&v| {
                        graph
                            .var_neighbors(v)
                            .iter()
                            .position(|&x| x == c)
                            .expect("Tanner graph: check missing from its variable's neighbours")
                    })
                    .collect()
            })
            .collect();
        let chk_pos_in_var = (0..graph.n_var)
            .map(|v| {
                graph
                    .var_neighbors(v)
                    .iter()
                    .map(|&c| {
                        graph
                            .chk_neighbors(c)
                            .iter()
                            .position(|&x| x == v)
                            .expect("Tanner graph: variable missing from its check's neighbours")
                    })
                    .collect()
            })
            .collect();
        Self {
            var_pos_in_chk,
            chk_pos_in_var,
        }
    }
}
// Bit-flipping

/// Hard-decision majority bit-flipping decoder.
///
/// Only the signs of the channel LLRs are used; their magnitudes are ignored.
pub struct BitFlippingDecoder {
    graph: TannerGraph,
    h: SparseMatrixGF2,
    config: DecoderConfig,
}

impl BitFlippingDecoder {
    /// Creates a decoder that owns copies of `code`'s Tanner graph and
    /// parity-check matrix.
    pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
        let graph = code.graph.clone();
        let h = code.h.clone();
        Self { graph, h, config }
    }
}
impl Decoder for BitFlippingDecoder {
    /// Iterative majority bit flipping.
    ///
    /// Each iteration flips, in parallel, every bit with more than half of its
    /// checks unsatisfied. Decoding stops at the first of:
    /// - an all-zero syndrome, when `early_stopping` is enabled;
    /// - an iteration in which no bit flips;
    /// - `max_iterations` iterations.
    ///
    /// Returns [`DecoderResult::Failure`] when `channel_llr.len()` differs from the
    /// number of variable nodes, instead of panicking on an out-of-range index.
    fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
        let n_var = self.graph.n_var;
        if channel_llr.len() != n_var {
            return DecoderResult::Failure;
        }
        let mut bits: Vec<Gf2> = channel_llr.iter().map(|&l| hard_decision(l)).collect();
        // Reused across iterations instead of being reallocated each time.
        let mut unsatisfied = vec![0usize; n_var];
        for _ in 0..self.config.max_iterations {
            let syndrome = compute_syndrome(&self.h, &bits);
            if self.config.early_stopping && syndrome.iter().all(|&s| s == 0) {
                return DecoderResult::Converged(bits);
            }
            unsatisfied.fill(0);
            for c in 0..self.graph.n_chk {
                // A check is unsatisfied exactly when its syndrome bit is nonzero,
                // mirroring the all-zero convergence test above.
                if syndrome[c] != 0 {
                    for &v in self.graph.chk_neighbors(c) {
                        unsatisfied[v] += 1;
                    }
                }
            }
            let mut flipped = false;
            for (v, bit) in bits.iter_mut().enumerate() {
                if unsatisfied[v] > self.graph.var_degree(v) / 2 {
                    *bit ^= 1;
                    flipped = true;
                }
            }
            if !flipped {
                // Stuck: further iterations would not change anything.
                break;
            }
        }
        if compute_syndrome(&self.h, &bits).iter().all(|&s| s == 0) {
            DecoderResult::Converged(bits)
        } else {
            DecoderResult::MaxIterationsReached(bits)
        }
    }
}
// Belief-propagation core shared by Sum-Product and Min-Sum

/// Flooding-schedule belief propagation.
///
/// Shared by the Sum-Product and Min-Sum decoders; they differ only in `check_update`.
///
/// Each iteration updates every check node from its variable-to-check messages,
/// then every variable node from its check-to-variable messages. With
/// `early_stopping`, the hard decision is tested after each iteration and the
/// function returns as soon as its syndrome is all zero.
///
/// Returns [`DecoderResult::Failure`] when `channel_llr.len()` differs from the
/// number of variable nodes, instead of panicking on an out-of-range index.
fn bp_decode<F>(
    graph: &TannerGraph,
    h: &SparseMatrixGF2,
    config: &DecoderConfig,
    channel_llr: &[Llr],
    edge_idx: &EdgeIndex,
    check_update: F,
) -> DecoderResult
where
    F: Fn(&[Llr], &mut [Llr]),
{
    if channel_llr.len() != graph.n_var {
        return DecoderResult::Failure;
    }
    let mut msgs = Messages::new(graph);
    // Initialisation: v2c[c][j] = L_channel(chk_neighbors(c)[j]).
    for c in 0..graph.n_chk {
        for (j, &v) in graph.chk_neighbors(c).iter().enumerate() {
            msgs.v2c[c][j] = channel_llr[v];
        }
    }
    // Scratch buffers reused for every variable node in every iteration.
    let mut incoming: Vec<Llr> = Vec::new();
    let mut outgoing: Vec<Llr> = Vec::new();
    for _ in 0..config.max_iterations {
        // Check-node update. `v2c` and `c2v` are distinct fields, so both rows can
        // be borrowed at once without cloning the input row.
        for (v2c_row, c2v_row) in msgs.v2c.iter().zip(msgs.c2v.iter_mut()) {
            check_update(v2c_row.as_slice(), c2v_row.as_mut_slice());
        }
        // Variable-node update: gather incoming c2v messages, compute the
        // extrinsic v2c messages, scatter them back into the per-check layout.
        for v in 0..graph.n_var {
            let neighbors = graph.var_neighbors(v);
            let positions = &edge_idx.chk_pos_in_var[v];
            incoming.clear();
            incoming.extend(
                neighbors
                    .iter()
                    .zip(positions)
                    .map(|(&c, &j)| msgs.c2v[c][j]),
            );
            outgoing.clear();
            outgoing.resize(incoming.len(), 0.0);
            variable_node_update(channel_llr[v], &incoming, &mut outgoing);
            for ((&c, &j), &msg) in neighbors.iter().zip(positions).zip(&outgoing) {
                msgs.v2c[c][j] = msg;
            }
        }
        // Hard decision + early stopping.
        if config.early_stopping {
            let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
            if compute_syndrome(h, &bits).iter().all(|&s| s == 0) {
                return DecoderResult::Converged(bits);
            }
        }
    }
    let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
    if compute_syndrome(h, &bits).iter().all(|&s| s == 0) {
        DecoderResult::Converged(bits)
    } else {
        DecoderResult::MaxIterationsReached(bits)
    }
}
/// Hard decision on the posterior LLR of every variable node.
///
/// The posterior of `v` is its channel LLR plus all check-to-variable messages
/// currently addressed to it. They are located through `edge_idx.chk_pos_in_var`.
fn make_decision(
    graph: &TannerGraph,
    msgs: &Messages,
    channel_llr: &[Llr],
    edge_idx: &EdgeIndex,
) -> Vec<Gf2> {
    // One scratch buffer for all variables instead of one allocation each.
    let mut scratch: Vec<Llr> = Vec::new();
    let mut decision: Vec<Gf2> = Vec::with_capacity(graph.n_var);
    for v in 0..graph.n_var {
        scratch.clear();
        scratch.extend(
            graph
                .var_neighbors(v)
                .iter()
                .zip(&edge_idx.chk_pos_in_var[v])
                .map(|(&c, &j)| msgs.c2v[c][j]),
        );
        decision.push(hard_decision(posterior_llr(channel_llr[v], &scratch)));
    }
    decision
}
// Sum-Product

/// Belief-propagation decoder using the Sum-Product (φ-based) check-node rule.
pub struct SumProductDecoder {
    graph: TannerGraph,
    h: SparseMatrixGF2,
    config: DecoderConfig,
    edge_idx: EdgeIndex,
}

impl SumProductDecoder {
    /// Creates a decoder for `code`.
    ///
    /// Precomputes the edge lookup tables, then copies the graph and the
    /// parity-check matrix.
    pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
        Self {
            edge_idx: EdgeIndex::build(&code.graph),
            graph: code.graph.clone(),
            h: code.h.clone(),
            config,
        }
    }
}
impl Decoder for SumProductDecoder {
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
bp_decode(
&self.graph,
&self.h,
&self.config,
channel_llr,
&self.edge_idx,
|incoming, out| check_node_update_sp(incoming, out),
)
}
}
// Min-Sum

/// Belief-propagation decoder using the normalised Min-Sum check-node rule.
pub struct MinSumDecoder {
    graph: TannerGraph,
    h: SparseMatrixGF2,
    config: DecoderConfig,
    /// Normalisation factor applied to every check-node output.
    scaling_factor: Llr,
    edge_idx: EdgeIndex,
}

impl MinSumDecoder {
    /// Creates a decoder for `code` with the given normalisation factor.
    ///
    /// Precomputes the edge lookup tables, then copies the graph and the
    /// parity-check matrix.
    pub fn new(code: &LdpcCode, config: DecoderConfig, scaling_factor: Llr) -> Self {
        Self {
            edge_idx: EdgeIndex::build(&code.graph),
            graph: code.graph.clone(),
            h: code.h.clone(),
            config,
            scaling_factor,
        }
    }
}
impl Decoder for MinSumDecoder {
    /// Runs normalised Min-Sum belief propagation on `channel_llr`.
    ///
    /// Every check-node output is scaled by `scaling_factor`.
    fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
        let Self {
            graph,
            h,
            config,
            scaling_factor,
            edge_idx,
        } = self;
        let alpha = *scaling_factor;
        bp_decode(
            graph,
            h,
            config,
            channel_llr,
            edge_idx,
            |incoming, out| check_node_update_ms(incoming, out, alpha),
        )
    }
}
// Factory

/// Selects which decoding algorithm [`build_decoder`] instantiates.
#[derive(Debug, Clone)]
pub enum DecoderMethod {
    /// Hard-decision majority bit flipping.
    BitFlipping,
    /// Sum-Product belief propagation.
    SumProduct,
    /// Normalised Min-Sum belief propagation with the given scaling factor.
    MinSum { scaling_factor: Llr },
}

/// Builds a boxed decoder of the requested kind for `code`.
pub fn build_decoder(
    code: &LdpcCode,
    method: DecoderMethod,
    config: DecoderConfig,
) -> Box<dyn Decoder> {
    match method {
        DecoderMethod::MinSum { scaling_factor } => {
            Box::new(MinSumDecoder::new(code, config, scaling_factor))
        }
        DecoderMethod::SumProduct => Box::new(SumProductDecoder::new(code, config)),
        DecoderMethod::BitFlipping => Box::new(BitFlippingDecoder::new(code, config)),
    }
}
// #[cfg(test)]
// mod tests {
// use super::*;
//
// #[test]
// fn test_phi_is_involution() {
// let x = 2.5_f64;
// assert!((phi(phi(x)) - x).abs() < 1e-9);
// }
//
// #[test]
// fn test_hard_decision() {
// assert_eq!(hard_decision(0.1), 0);
// assert_eq!(hard_decision(-0.1), 1);
// assert_eq!(hard_decision(0.0), 0);
// }
// }