t
This commit is contained in:
446
srctmp/decoder.rs
Normal file
446
srctmp/decoder.rs
Normal file
@ -0,0 +1,446 @@
|
||||
use crate::code::LdpcCode;
|
||||
use crate::graph::TannerGraph;
|
||||
use crate::matrix::SparseMatrixGF2;
|
||||
use crate::{BitVec, Gf2, Llr};
|
||||
|
||||
// Résultat
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DecoderResult {
|
||||
Converged(BitVec),
|
||||
MaxIterationsReached(BitVec),
|
||||
Failure,
|
||||
}
|
||||
|
||||
impl DecoderResult {
|
||||
pub fn codeword(&self) -> Option<&BitVec> {
|
||||
match self {
|
||||
DecoderResult::Converged(c) | DecoderResult::MaxIterationsReached(c) => Some(c),
|
||||
DecoderResult::Failure => None,
|
||||
}
|
||||
}
|
||||
pub fn is_success(&self) -> bool {
|
||||
matches!(self, DecoderResult::Converged(_))
|
||||
}
|
||||
}
|
||||
|
||||
// Configuration
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecoderConfig {
|
||||
pub max_iterations: usize,
|
||||
pub early_stopping: bool,
|
||||
}
|
||||
|
||||
impl Default for DecoderConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_iterations: 50,
|
||||
early_stopping: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Trait Decoder
|
||||
|
||||
pub trait Decoder: Send + Sync {
|
||||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult;
|
||||
|
||||
fn decode_hard(&self, received: &[Gf2]) -> DecoderResult {
|
||||
let llr: Vec<Llr> = received
|
||||
.iter()
|
||||
.map(|&b| if b == 0 { 1.0 } else { -1.0 })
|
||||
.collect();
|
||||
self.decode(&llr)
|
||||
}
|
||||
}
|
||||
|
||||
// Primitives GF(2) et LLR
|
||||
|
||||
#[inline]
|
||||
pub fn hard_decision(llr: Llr) -> Gf2 {
|
||||
if llr >= 0.0 {
|
||||
0
|
||||
} else {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compute_syndrome(h: &SparseMatrixGF2, c: &[Gf2]) -> Vec<Gf2> {
|
||||
h.multiply_vec(c)
|
||||
}
|
||||
|
||||
// phi(x) = -ln(tanh(|x|/2)) involution de Sum-Product
|
||||
// phi(phi(x)) = x
|
||||
#[inline]
|
||||
fn phi(x: Llr) -> Llr {
|
||||
let ax = x.abs().max(1e-10);
|
||||
-((ax / 2.0).tanh().ln())
|
||||
}
|
||||
|
||||
// Mises à jour des nœuds
|
||||
|
||||
// Mise à jour Sum-Product du nœud de contrôle
|
||||
// R_{c→v} = φ(∑_{v'≠v} φ(|L_{v'→c}|)) × sign(∏_{v'≠v} sign(L_{v'→c}))
|
||||
fn check_node_update_sp(incoming: &[Llr], out: &mut [Llr]) {
|
||||
let phi_sum: Llr = incoming.iter().map(|&l| phi(l.abs())).sum();
|
||||
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
||||
for (j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
||||
let phi_excl = phi_sum - phi(l.abs());
|
||||
let sign_excl = sign_prod * l.signum();
|
||||
*r = sign_excl * phi(phi_excl);
|
||||
}
|
||||
}
|
||||
|
||||
// Mise à jour Min-Sum avec facteur de normalisation α
|
||||
// Approxime φ(∑ φ(|L|)) ≈ min(|L|).
|
||||
// alpha in [0.75, 0.875] compense le biais de Min-Sum brut
|
||||
fn check_node_update_ms(incoming: &[Llr], out: &mut [Llr], alpha: Llr) {
|
||||
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
||||
// Précalcul des deux plus petites valeurs absolues
|
||||
let mut min1 = Llr::INFINITY;
|
||||
let mut min2 = Llr::INFINITY;
|
||||
let mut min1_idx = 0;
|
||||
for (j, &l) in incoming.iter().enumerate() {
|
||||
let al = l.abs();
|
||||
if al < min1 {
|
||||
min2 = min1;
|
||||
min1 = al;
|
||||
min1_idx = j;
|
||||
} else if al < min2 {
|
||||
min2 = al;
|
||||
}
|
||||
}
|
||||
for (j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
||||
let min_excl = if j == min1_idx { min2 } else { min1 };
|
||||
let sign_excl = sign_prod * l.signum();
|
||||
*r = alpha * sign_excl * min_excl;
|
||||
}
|
||||
}
|
||||
|
||||
// Mise à jour du nœud de variable
|
||||
// L_{v→c} = L_channel(v) + ∑_{c'≠c} R_{c'→v}
|
||||
fn variable_node_update(ch_llr: Llr, incoming_c2v: &[Llr], out: &mut [Llr]) {
|
||||
let total: Llr = ch_llr + incoming_c2v.iter().sum::<Llr>();
|
||||
for ((&r, o)) in incoming_c2v.iter().zip(out.iter_mut()) {
|
||||
*o = total - r;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn posterior_llr(ch_llr: Llr, c2v_msgs: &[Llr]) -> Llr {
|
||||
ch_llr + c2v_msgs.iter().sum::<Llr>()
|
||||
}
|
||||
|
||||
// Messages internes
|
||||
// Indexés par (check_id, position_dans_liste_voisins) accès O(1)
|
||||
|
||||
struct Messages {
|
||||
v2c: Vec<Vec<Llr>>, // v2c[c][j] : var_neighbor(c)[j] -> check c
|
||||
c2v: Vec<Vec<Llr>>, // c2v[c][j] : check c -> var_neighbor(c)[j]
|
||||
}
|
||||
|
||||
impl Messages {
|
||||
fn new(graph: &TannerGraph) -> Self {
|
||||
let v2c = (0..graph.n_chk)
|
||||
.map(|c| vec![0.0; graph.chk_degree(c)])
|
||||
.collect();
|
||||
let c2v = (0..graph.n_chk)
|
||||
.map(|c| vec![0.0; graph.chk_degree(c)])
|
||||
.collect();
|
||||
Self { v2c, c2v }
|
||||
}
|
||||
}
|
||||
|
||||
// Table de correspondance, pour chaque (var, check), index dans la liste de voisins
|
||||
// Précalculée une fois à la construction du décodeur
|
||||
struct EdgeIndex {
|
||||
// var_pos_in_chk[c][j] : position de var_neighbor(c)[j] dans var_to_chk[v]
|
||||
var_pos_in_chk: Vec<Vec<usize>>,
|
||||
// chk_pos_in_var[v][i] : position de var_neighbor(v)[i] dans chk_to_var[c]
|
||||
chk_pos_in_var: Vec<Vec<usize>>,
|
||||
}
|
||||
|
||||
impl EdgeIndex {
|
||||
fn build(graph: &TannerGraph) -> Self {
|
||||
let var_pos_in_chk = (0..graph.n_chk)
|
||||
.map(|c| {
|
||||
graph
|
||||
.chk_neighbors(c)
|
||||
.iter()
|
||||
.map(|&v| graph.var_neighbors(v).iter().position(|&x| x == c).unwrap())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let chk_pos_in_var = (0..graph.n_var)
|
||||
.map(|v| {
|
||||
graph
|
||||
.var_neighbors(v)
|
||||
.iter()
|
||||
.map(|&c| graph.chk_neighbors(c).iter().position(|&x| x == v).unwrap())
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
var_pos_in_chk,
|
||||
chk_pos_in_var,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Bit-Flipping
|
||||
|
||||
pub struct BitFlippingDecoder {
|
||||
graph: TannerGraph,
|
||||
h: SparseMatrixGF2,
|
||||
config: DecoderConfig,
|
||||
}
|
||||
|
||||
impl BitFlippingDecoder {
|
||||
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
||||
Self {
|
||||
graph: code.graph.clone(),
|
||||
h: code.h.clone(),
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BitFlippingDecoder {
|
||||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||||
let mut bits: Vec<Gf2> = channel_llr.iter().map(|&l| hard_decision(l)).collect();
|
||||
|
||||
for _iter in 0..self.config.max_iterations {
|
||||
let syndrome = compute_syndrome(&self.h, &bits);
|
||||
if self.config.early_stopping && syndrome.iter().all(|&s| s == 0) {
|
||||
return DecoderResult::Converged(bits);
|
||||
}
|
||||
let mut unsatisfied = vec![0usize; self.graph.n_var];
|
||||
for c in 0..self.graph.n_chk {
|
||||
if syndrome[c] == 1 {
|
||||
for &v in self.graph.chk_neighbors(c) {
|
||||
unsatisfied[v] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut flipped = false;
|
||||
for v in 0..self.graph.n_var {
|
||||
if unsatisfied[v] > self.graph.var_degree(v) / 2 {
|
||||
bits[v] ^= 1;
|
||||
flipped = true;
|
||||
}
|
||||
}
|
||||
if !flipped {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let synd = compute_syndrome(&self.h, &bits);
|
||||
if synd.iter().all(|&s| s == 0) {
|
||||
DecoderResult::Converged(bits)
|
||||
} else {
|
||||
DecoderResult::MaxIterationsReached(bits)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Noyau BP partagé par SP et MinSum
|
||||
|
||||
fn bp_decode<F>(
|
||||
graph: &TannerGraph,
|
||||
h: &SparseMatrixGF2,
|
||||
config: &DecoderConfig,
|
||||
channel_llr: &[Llr],
|
||||
edge_idx: &EdgeIndex,
|
||||
check_update: F,
|
||||
) -> DecoderResult
|
||||
where
|
||||
F: Fn(&[Llr], &mut [Llr]),
|
||||
{
|
||||
let mut msgs = Messages::new(graph);
|
||||
|
||||
// Initialisation : v2c[c][j] = L_channel(var_neighbor(c)[j])
|
||||
for c in 0..graph.n_chk {
|
||||
for (j, &v) in graph.chk_neighbors(c).iter().enumerate() {
|
||||
msgs.v2c[c][j] = channel_llr[v];
|
||||
}
|
||||
}
|
||||
|
||||
for _iter in 0..config.max_iterations {
|
||||
// Mise à jour des check-nodes
|
||||
for c in 0..graph.n_chk {
|
||||
let v2c = msgs.v2c[c].clone();
|
||||
check_update(&v2c, &mut msgs.c2v[c]);
|
||||
}
|
||||
|
||||
// Mise à jour des var-nodes
|
||||
for v in 0..graph.n_var {
|
||||
let neighbors = graph.var_neighbors(v);
|
||||
// Rassembler les c2v entrants sur ce var-node
|
||||
let incoming: Vec<Llr> = neighbors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &c)| {
|
||||
let j = edge_idx.chk_pos_in_var[v][i];
|
||||
msgs.c2v[c][j]
|
||||
})
|
||||
.collect();
|
||||
let mut new_v2c = vec![0.0; neighbors.len()];
|
||||
variable_node_update(channel_llr[v], &incoming, &mut new_v2c);
|
||||
for (i, &c) in neighbors.iter().enumerate() {
|
||||
let j = edge_idx.chk_pos_in_var[v][i];
|
||||
msgs.v2c[c][j] = new_v2c[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Hard décision + arrêt
|
||||
if config.early_stopping {
|
||||
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
||||
if compute_syndrome(h, &bits).iter().all(|&s| s == 0) {
|
||||
return DecoderResult::Converged(bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
||||
let synd = compute_syndrome(h, &bits);
|
||||
if synd.iter().all(|&s| s == 0) {
|
||||
DecoderResult::Converged(bits)
|
||||
} else {
|
||||
DecoderResult::MaxIterationsReached(bits)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_decision(
|
||||
graph: &TannerGraph,
|
||||
msgs: &Messages,
|
||||
channel_llr: &[Llr],
|
||||
edge_idx: &EdgeIndex,
|
||||
) -> Vec<Gf2> {
|
||||
(0..graph.n_var)
|
||||
.map(|v| {
|
||||
let incoming: Vec<Llr> = graph
|
||||
.var_neighbors(v)
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &c)| {
|
||||
let j = edge_idx.chk_pos_in_var[v][i];
|
||||
msgs.c2v[c][j]
|
||||
})
|
||||
.collect();
|
||||
hard_decision(posterior_llr(channel_llr[v], &incoming))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Sum-Product
|
||||
|
||||
pub struct SumProductDecoder {
|
||||
graph: TannerGraph,
|
||||
h: SparseMatrixGF2,
|
||||
config: DecoderConfig,
|
||||
edge_idx: EdgeIndex,
|
||||
}
|
||||
|
||||
impl SumProductDecoder {
|
||||
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
||||
let edge_idx = EdgeIndex::build(&code.graph);
|
||||
Self {
|
||||
graph: code.graph.clone(),
|
||||
h: code.h.clone(),
|
||||
config,
|
||||
edge_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for SumProductDecoder {
|
||||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||||
bp_decode(
|
||||
&self.graph,
|
||||
&self.h,
|
||||
&self.config,
|
||||
channel_llr,
|
||||
&self.edge_idx,
|
||||
|incoming, out| check_node_update_sp(incoming, out),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Min-Sum
|
||||
|
||||
pub struct MinSumDecoder {
|
||||
graph: TannerGraph,
|
||||
h: SparseMatrixGF2,
|
||||
config: DecoderConfig,
|
||||
scaling_factor: Llr,
|
||||
edge_idx: EdgeIndex,
|
||||
}
|
||||
|
||||
impl MinSumDecoder {
|
||||
pub fn new(code: &LdpcCode, config: DecoderConfig, scaling_factor: Llr) -> Self {
|
||||
let edge_idx = EdgeIndex::build(&code.graph);
|
||||
Self {
|
||||
graph: code.graph.clone(),
|
||||
h: code.h.clone(),
|
||||
config,
|
||||
scaling_factor,
|
||||
edge_idx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for MinSumDecoder {
|
||||
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
||||
let alpha = self.scaling_factor;
|
||||
bp_decode(
|
||||
&self.graph,
|
||||
&self.h,
|
||||
&self.config,
|
||||
channel_llr,
|
||||
&self.edge_idx,
|
||||
move |incoming, out| check_node_update_ms(incoming, out, alpha),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Factory
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DecoderMethod {
|
||||
BitFlipping,
|
||||
SumProduct,
|
||||
MinSum { scaling_factor: Llr },
|
||||
}
|
||||
|
||||
pub fn build_decoder(
|
||||
code: &LdpcCode,
|
||||
method: DecoderMethod,
|
||||
config: DecoderConfig,
|
||||
) -> Box<dyn Decoder> {
|
||||
match method {
|
||||
DecoderMethod::BitFlipping => Box::new(BitFlippingDecoder::new(code, config)),
|
||||
DecoderMethod::SumProduct => Box::new(SumProductDecoder::new(code, config)),
|
||||
DecoderMethod::MinSum { scaling_factor } => {
|
||||
Box::new(MinSumDecoder::new(code, config, scaling_factor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
//
|
||||
// #[test]
|
||||
// fn test_phi_is_involution() {
|
||||
// let x = 2.5_f64;
|
||||
// assert!((phi(phi(x)) - x).abs() < 1e-9);
|
||||
// }
|
||||
//
|
||||
// #[test]
|
||||
// fn test_hard_decision() {
|
||||
// assert_eq!(hard_decision(0.1), 0);
|
||||
// assert_eq!(hard_decision(-0.1), 1);
|
||||
// assert_eq!(hard_decision(0.0), 0);
|
||||
// }
|
||||
// }
|
||||
Reference in New Issue
Block a user