447 lines
12 KiB
Rust
447 lines
12 KiB
Rust
use crate::code::LdpcCode;
|
||
use crate::graph::TannerGraph;
|
||
use crate::matrix::SparseMatrixGF2;
|
||
use crate::{BitVec, Gf2, Llr};
|
||
|
||
// Résultat
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub enum DecoderResult {
|
||
Converged(BitVec),
|
||
MaxIterationsReached(BitVec),
|
||
Failure,
|
||
}
|
||
|
||
impl DecoderResult {
|
||
pub fn codeword(&self) -> Option<&BitVec> {
|
||
match self {
|
||
DecoderResult::Converged(c) | DecoderResult::MaxIterationsReached(c) => Some(c),
|
||
DecoderResult::Failure => None,
|
||
}
|
||
}
|
||
pub fn is_success(&self) -> bool {
|
||
matches!(self, DecoderResult::Converged(_))
|
||
}
|
||
}
|
||
|
||
// Configuration
|
||
|
||
/// Settings shared by every decoder in this module.
#[derive(Debug, Clone)]
pub struct DecoderConfig {
    /// Upper bound on the number of decoding iterations.
    pub max_iterations: usize,
    /// When `true`, decoders check the syndrome during iteration and return as soon as it is zero.
    pub early_stopping: bool,
}

impl Default for DecoderConfig {
    /// 50 iterations, early stopping enabled.
    fn default() -> Self {
        let max_iterations = 50;
        let early_stopping = true;
        Self {
            max_iterations,
            early_stopping,
        }
    }
}
|
||
|
||
// Trait Decoder
|
||
|
||
pub trait Decoder: Send + Sync {
|
||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult;
|
||
|
||
fn decode_hard(&self, received: &[Gf2]) -> DecoderResult {
|
||
let llr: Vec<Llr> = received
|
||
.iter()
|
||
.map(|&b| if b == 0 { 1.0 } else { -1.0 })
|
||
.collect();
|
||
self.decode(&llr)
|
||
}
|
||
}
|
||
|
||
// Primitives GF(2) et LLR
|
||
|
||
#[inline]
|
||
pub fn hard_decision(llr: Llr) -> Gf2 {
|
||
if llr >= 0.0 {
|
||
0
|
||
} else {
|
||
1
|
||
}
|
||
}
|
||
|
||
pub fn compute_syndrome(h: &SparseMatrixGF2, c: &[Gf2]) -> Vec<Gf2> {
|
||
h.multiply_vec(c)
|
||
}
|
||
|
||
// phi(x) = -ln(tanh(|x|/2)) involution de Sum-Product
|
||
// phi(phi(x)) = x
|
||
#[inline]
|
||
fn phi(x: Llr) -> Llr {
|
||
let ax = x.abs().max(1e-10);
|
||
-((ax / 2.0).tanh().ln())
|
||
}
|
||
|
||
// Mises à jour des nœuds
|
||
|
||
// Mise à jour Sum-Product du nœud de contrôle
|
||
// R_{c→v} = φ(∑_{v'≠v} φ(|L_{v'→c}|)) × sign(∏_{v'≠v} sign(L_{v'→c}))
|
||
fn check_node_update_sp(incoming: &[Llr], out: &mut [Llr]) {
|
||
let phi_sum: Llr = incoming.iter().map(|&l| phi(l.abs())).sum();
|
||
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
||
for (j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
||
let phi_excl = phi_sum - phi(l.abs());
|
||
let sign_excl = sign_prod * l.signum();
|
||
*r = sign_excl * phi(phi_excl);
|
||
}
|
||
}
|
||
|
||
// Mise à jour Min-Sum avec facteur de normalisation α
|
||
// Approxime φ(∑ φ(|L|)) ≈ min(|L|).
|
||
// alpha in [0.75, 0.875] compense le biais de Min-Sum brut
|
||
fn check_node_update_ms(incoming: &[Llr], out: &mut [Llr], alpha: Llr) {
|
||
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
||
// Précalcul des deux plus petites valeurs absolues
|
||
let mut min1 = Llr::INFINITY;
|
||
let mut min2 = Llr::INFINITY;
|
||
let mut min1_idx = 0;
|
||
for (j, &l) in incoming.iter().enumerate() {
|
||
let al = l.abs();
|
||
if al < min1 {
|
||
min2 = min1;
|
||
min1 = al;
|
||
min1_idx = j;
|
||
} else if al < min2 {
|
||
min2 = al;
|
||
}
|
||
}
|
||
for (j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
||
let min_excl = if j == min1_idx { min2 } else { min1 };
|
||
let sign_excl = sign_prod * l.signum();
|
||
*r = alpha * sign_excl * min_excl;
|
||
}
|
||
}
|
||
|
||
// Mise à jour du nœud de variable
|
||
// L_{v→c} = L_channel(v) + ∑_{c'≠c} R_{c'→v}
|
||
fn variable_node_update(ch_llr: Llr, incoming_c2v: &[Llr], out: &mut [Llr]) {
|
||
let total: Llr = ch_llr + incoming_c2v.iter().sum::<Llr>();
|
||
for ((&r, o)) in incoming_c2v.iter().zip(out.iter_mut()) {
|
||
*o = total - r;
|
||
}
|
||
}
|
||
|
||
#[inline]
|
||
fn posterior_llr(ch_llr: Llr, c2v_msgs: &[Llr]) -> Llr {
|
||
ch_llr + c2v_msgs.iter().sum::<Llr>()
|
||
}
|
||
|
||
// Messages internes
|
||
// Indexés par (check_id, position_dans_liste_voisins) accès O(1)
|
||
|
||
struct Messages {
|
||
v2c: Vec<Vec<Llr>>, // v2c[c][j] : var_neighbor(c)[j] -> check c
|
||
c2v: Vec<Vec<Llr>>, // c2v[c][j] : check c -> var_neighbor(c)[j]
|
||
}
|
||
|
||
impl Messages {
|
||
fn new(graph: &TannerGraph) -> Self {
|
||
let v2c = (0..graph.n_chk)
|
||
.map(|c| vec![0.0; graph.chk_degree(c)])
|
||
.collect();
|
||
let c2v = (0..graph.n_chk)
|
||
.map(|c| vec![0.0; graph.chk_degree(c)])
|
||
.collect();
|
||
Self { v2c, c2v }
|
||
}
|
||
}
|
||
|
||
// Table de correspondance, pour chaque (var, check), index dans la liste de voisins
|
||
// Précalculée une fois à la construction du décodeur
|
||
struct EdgeIndex {
|
||
// var_pos_in_chk[c][j] : position de var_neighbor(c)[j] dans var_to_chk[v]
|
||
var_pos_in_chk: Vec<Vec<usize>>,
|
||
// chk_pos_in_var[v][i] : position de var_neighbor(v)[i] dans chk_to_var[c]
|
||
chk_pos_in_var: Vec<Vec<usize>>,
|
||
}
|
||
|
||
impl EdgeIndex {
|
||
fn build(graph: &TannerGraph) -> Self {
|
||
let var_pos_in_chk = (0..graph.n_chk)
|
||
.map(|c| {
|
||
graph
|
||
.chk_neighbors(c)
|
||
.iter()
|
||
.map(|&v| graph.var_neighbors(v).iter().position(|&x| x == c).unwrap())
|
||
.collect()
|
||
})
|
||
.collect();
|
||
let chk_pos_in_var = (0..graph.n_var)
|
||
.map(|v| {
|
||
graph
|
||
.var_neighbors(v)
|
||
.iter()
|
||
.map(|&c| graph.chk_neighbors(c).iter().position(|&x| x == v).unwrap())
|
||
.collect()
|
||
})
|
||
.collect();
|
||
Self {
|
||
var_pos_in_chk,
|
||
chk_pos_in_var,
|
||
}
|
||
}
|
||
}
|
||
|
||
// Bit-Flipping
|
||
|
||
pub struct BitFlippingDecoder {
|
||
graph: TannerGraph,
|
||
h: SparseMatrixGF2,
|
||
config: DecoderConfig,
|
||
}
|
||
|
||
impl BitFlippingDecoder {
|
||
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
||
Self {
|
||
graph: code.graph.clone(),
|
||
h: code.h.clone(),
|
||
config,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Decoder for BitFlippingDecoder {
|
||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||
let mut bits: Vec<Gf2> = channel_llr.iter().map(|&l| hard_decision(l)).collect();
|
||
|
||
for _iter in 0..self.config.max_iterations {
|
||
let syndrome = compute_syndrome(&self.h, &bits);
|
||
if self.config.early_stopping && syndrome.iter().all(|&s| s == 0) {
|
||
return DecoderResult::Converged(bits);
|
||
}
|
||
let mut unsatisfied = vec![0usize; self.graph.n_var];
|
||
for c in 0..self.graph.n_chk {
|
||
if syndrome[c] == 1 {
|
||
for &v in self.graph.chk_neighbors(c) {
|
||
unsatisfied[v] += 1;
|
||
}
|
||
}
|
||
}
|
||
let mut flipped = false;
|
||
for v in 0..self.graph.n_var {
|
||
if unsatisfied[v] > self.graph.var_degree(v) / 2 {
|
||
bits[v] ^= 1;
|
||
flipped = true;
|
||
}
|
||
}
|
||
if !flipped {
|
||
break;
|
||
}
|
||
}
|
||
|
||
let synd = compute_syndrome(&self.h, &bits);
|
||
if synd.iter().all(|&s| s == 0) {
|
||
DecoderResult::Converged(bits)
|
||
} else {
|
||
DecoderResult::MaxIterationsReached(bits)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Noyau BP partagé par SP et MinSum
|
||
|
||
fn bp_decode<F>(
|
||
graph: &TannerGraph,
|
||
h: &SparseMatrixGF2,
|
||
config: &DecoderConfig,
|
||
channel_llr: &[Llr],
|
||
edge_idx: &EdgeIndex,
|
||
check_update: F,
|
||
) -> DecoderResult
|
||
where
|
||
F: Fn(&[Llr], &mut [Llr]),
|
||
{
|
||
let mut msgs = Messages::new(graph);
|
||
|
||
// Initialisation : v2c[c][j] = L_channel(var_neighbor(c)[j])
|
||
for c in 0..graph.n_chk {
|
||
for (j, &v) in graph.chk_neighbors(c).iter().enumerate() {
|
||
msgs.v2c[c][j] = channel_llr[v];
|
||
}
|
||
}
|
||
|
||
for _iter in 0..config.max_iterations {
|
||
// Mise à jour des check-nodes
|
||
for c in 0..graph.n_chk {
|
||
let v2c = msgs.v2c[c].clone();
|
||
check_update(&v2c, &mut msgs.c2v[c]);
|
||
}
|
||
|
||
// Mise à jour des var-nodes
|
||
for v in 0..graph.n_var {
|
||
let neighbors = graph.var_neighbors(v);
|
||
// Rassembler les c2v entrants sur ce var-node
|
||
let incoming: Vec<Llr> = neighbors
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(i, &c)| {
|
||
let j = edge_idx.chk_pos_in_var[v][i];
|
||
msgs.c2v[c][j]
|
||
})
|
||
.collect();
|
||
let mut new_v2c = vec![0.0; neighbors.len()];
|
||
variable_node_update(channel_llr[v], &incoming, &mut new_v2c);
|
||
for (i, &c) in neighbors.iter().enumerate() {
|
||
let j = edge_idx.chk_pos_in_var[v][i];
|
||
msgs.v2c[c][j] = new_v2c[i];
|
||
}
|
||
}
|
||
|
||
// Hard décision + arrêt
|
||
if config.early_stopping {
|
||
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
||
if compute_syndrome(h, &bits).iter().all(|&s| s == 0) {
|
||
return DecoderResult::Converged(bits);
|
||
}
|
||
}
|
||
}
|
||
|
||
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
||
let synd = compute_syndrome(h, &bits);
|
||
if synd.iter().all(|&s| s == 0) {
|
||
DecoderResult::Converged(bits)
|
||
} else {
|
||
DecoderResult::MaxIterationsReached(bits)
|
||
}
|
||
}
|
||
|
||
fn make_decision(
|
||
graph: &TannerGraph,
|
||
msgs: &Messages,
|
||
channel_llr: &[Llr],
|
||
edge_idx: &EdgeIndex,
|
||
) -> Vec<Gf2> {
|
||
(0..graph.n_var)
|
||
.map(|v| {
|
||
let incoming: Vec<Llr> = graph
|
||
.var_neighbors(v)
|
||
.iter()
|
||
.enumerate()
|
||
.map(|(i, &c)| {
|
||
let j = edge_idx.chk_pos_in_var[v][i];
|
||
msgs.c2v[c][j]
|
||
})
|
||
.collect();
|
||
hard_decision(posterior_llr(channel_llr[v], &incoming))
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
// Sum-Product
|
||
|
||
pub struct SumProductDecoder {
|
||
graph: TannerGraph,
|
||
h: SparseMatrixGF2,
|
||
config: DecoderConfig,
|
||
edge_idx: EdgeIndex,
|
||
}
|
||
|
||
impl SumProductDecoder {
|
||
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
||
let edge_idx = EdgeIndex::build(&code.graph);
|
||
Self {
|
||
graph: code.graph.clone(),
|
||
h: code.h.clone(),
|
||
config,
|
||
edge_idx,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Decoder for SumProductDecoder {
|
||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||
bp_decode(
|
||
&self.graph,
|
||
&self.h,
|
||
&self.config,
|
||
channel_llr,
|
||
&self.edge_idx,
|
||
|incoming, out| check_node_update_sp(incoming, out),
|
||
)
|
||
}
|
||
}
|
||
|
||
// Min-Sum
|
||
|
||
pub struct MinSumDecoder {
|
||
graph: TannerGraph,
|
||
h: SparseMatrixGF2,
|
||
config: DecoderConfig,
|
||
scaling_factor: Llr,
|
||
edge_idx: EdgeIndex,
|
||
}
|
||
|
||
impl MinSumDecoder {
|
||
pub fn new(code: &LdpcCode, config: DecoderConfig, scaling_factor: Llr) -> Self {
|
||
let edge_idx = EdgeIndex::build(&code.graph);
|
||
Self {
|
||
graph: code.graph.clone(),
|
||
h: code.h.clone(),
|
||
config,
|
||
scaling_factor,
|
||
edge_idx,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Decoder for MinSumDecoder {
|
||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||
let alpha = self.scaling_factor;
|
||
bp_decode(
|
||
&self.graph,
|
||
&self.h,
|
||
&self.config,
|
||
channel_llr,
|
||
&self.edge_idx,
|
||
move |incoming, out| check_node_update_ms(incoming, out, alpha),
|
||
)
|
||
}
|
||
}
|
||
|
||
// Factory
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub enum DecoderMethod {
|
||
BitFlipping,
|
||
SumProduct,
|
||
MinSum { scaling_factor: Llr },
|
||
}
|
||
|
||
pub fn build_decoder(
|
||
code: &LdpcCode,
|
||
method: DecoderMethod,
|
||
config: DecoderConfig,
|
||
) -> Box<dyn Decoder> {
|
||
match method {
|
||
DecoderMethod::BitFlipping => Box::new(BitFlippingDecoder::new(code, config)),
|
||
DecoderMethod::SumProduct => Box::new(SumProductDecoder::new(code, config)),
|
||
DecoderMethod::MinSum { scaling_factor } => {
|
||
Box::new(MinSumDecoder::new(code, config, scaling_factor))
|
||
}
|
||
}
|
||
}
|
||
|
||
// #[cfg(test)]
|
||
// mod tests {
|
||
// use super::*;
|
||
//
|
||
// #[test]
|
||
// fn test_phi_is_involution() {
|
||
// let x = 2.5_f64;
|
||
// assert!((phi(phi(x)) - x).abs() < 1e-9);
|
||
// }
|
||
//
|
||
// #[test]
|
||
// fn test_hard_decision() {
|
||
// assert_eq!(hard_decision(0.1), 0);
|
||
// assert_eq!(hard_decision(-0.1), 1);
|
||
// assert_eq!(hard_decision(0.0), 0);
|
||
// }
|
||
// }
|