411 lines
11 KiB
Rust
411 lines
11 KiB
Rust
use crate::code::LdpcCode;
|
|
use crate::graph::TannerGraph;
|
|
use crate::matrix::SparseMatrixGF2;
|
|
use crate::{BitVec, Gf2, Llr};
|
|
|
|
// Résultat
|
|
#[derive(Debug, Clone)]
|
|
pub enum DecoderResult {
|
|
Converged(BitVec),
|
|
MaxIterationsReached(BitVec),
|
|
Failure,
|
|
}
|
|
|
|
impl DecoderResult {
|
|
pub fn codeword(&self) -> Option<&BitVec> {
|
|
match self {
|
|
DecoderResult::Converged(c) | DecoderResult::MaxIterationsReached(c) => Some(c),
|
|
DecoderResult::Failure => None,
|
|
}
|
|
}
|
|
pub fn is_success(&self) -> bool {
|
|
matches!(self, DecoderResult::Converged(_))
|
|
}
|
|
}
|
|
|
|
// Configuration

/// Iteration-control parameters shared by every decoder.
#[derive(Clone, Debug)]
pub struct DecoderConfig {
    /// Upper bound on the number of decoding iterations.
    pub max_iterations: usize,
    /// When `true`, decoders test the syndrome on every iteration and return
    /// as soon as it is all zero.
    pub early_stopping: bool,
}

impl Default for DecoderConfig {
    /// 50 iterations with early stopping enabled.
    fn default() -> Self {
        let max_iterations = 50;
        let early_stopping = true;
        DecoderConfig {
            max_iterations,
            early_stopping,
        }
    }
}
|
|
|
|
// Trait Decoder
|
|
pub trait Decoder: Send + Sync {
|
|
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult;
|
|
|
|
fn decode_hard(&self, received: &[Gf2]) -> DecoderResult {
|
|
let llr: Vec<Llr> = received
|
|
.iter()
|
|
.map(|&b| if b == 0 { 1.0 } else { -1.0 })
|
|
.collect();
|
|
self.decode(&llr)
|
|
}
|
|
}
|
|
|
|
// GF(2) and LLR primitives
|
|
#[inline]
|
|
pub fn hard_decision(llr: Llr) -> Gf2 {
|
|
if llr >= 0.0 {
|
|
0
|
|
} else {
|
|
1
|
|
}
|
|
}
|
|
|
|
pub fn compute_syndrome(h: &SparseMatrixGF2, c: &[Gf2]) -> Vec<Gf2> {
|
|
h.multiply_vec(c)
|
|
}
|
|
|
|
#[inline]
|
|
fn phi(x: Llr) -> Llr {
|
|
let ax = x.abs().max(1e-10);
|
|
-((ax / 2.0).tanh().ln())
|
|
}
|
|
|
|
// Node updates
|
|
|
|
// Mise à jour Sum-Product du noeud de contrôle
|
|
fn check_node_update_sp(incoming: &[Llr], out: &mut [Llr]) {
|
|
let phi_sum: Llr = incoming.iter().map(|&l| phi(l.abs())).sum();
|
|
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
|
for (_j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
|
let phi_excl = phi_sum - phi(l.abs());
|
|
let sign_excl = sign_prod * l.signum();
|
|
*r = sign_excl * phi(phi_excl);
|
|
}
|
|
}
|
|
|
|
// Mise à jour Min-Sum avec facteur de normalisation a
|
|
// alpha in [0.75, 0.875] compense le biais de Min-Sum brut
|
|
fn check_node_update_ms(incoming: &[Llr], out: &mut [Llr], alpha: Llr) {
|
|
let sign_prod: Llr = incoming.iter().map(|&l| l.signum()).product();
|
|
let mut min1 = Llr::INFINITY;
|
|
let mut min2 = Llr::INFINITY;
|
|
let mut min1_idx = 0;
|
|
for (j, &l) in incoming.iter().enumerate() {
|
|
let al = l.abs();
|
|
if al < min1 {
|
|
min2 = min1;
|
|
min1 = al;
|
|
min1_idx = j;
|
|
} else if al < min2 {
|
|
min2 = al;
|
|
}
|
|
}
|
|
for (j, (&l, r)) in incoming.iter().zip(out.iter_mut()).enumerate() {
|
|
let min_excl = if j == min1_idx { min2 } else { min1 };
|
|
let sign_excl = sign_prod * l.signum();
|
|
*r = alpha * sign_excl * min_excl;
|
|
}
|
|
}
|
|
|
|
// Mise à jour du noeud de variable
|
|
fn variable_node_update(ch_llr: Llr, incoming_c2v: &[Llr], out: &mut [Llr]) {
|
|
let total: Llr = ch_llr + incoming_c2v.iter().sum::<Llr>();
|
|
for (&r, o) in incoming_c2v.iter().zip(out.iter_mut()) {
|
|
*o = total - r;
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
fn posterior_llr(ch_llr: Llr, c2v_msgs: &[Llr]) -> Llr {
|
|
ch_llr + c2v_msgs.iter().sum::<Llr>()
|
|
}
|
|
|
|
// Messages internes
|
|
// Indexés par (check_id, position_dans_liste_voisins) (O(1))
|
|
struct Messages {
|
|
v2c: Vec<Vec<Llr>>,
|
|
c2v: Vec<Vec<Llr>>,
|
|
}
|
|
|
|
impl Messages {
|
|
fn new(graph: &TannerGraph) -> Self {
|
|
let v2c = (0..graph.n_chk)
|
|
.map(|c| vec![0.0; graph.chk_degree(c)])
|
|
.collect();
|
|
let c2v = (0..graph.n_chk)
|
|
.map(|c| vec![0.0; graph.chk_degree(c)])
|
|
.collect();
|
|
Self { v2c, c2v }
|
|
}
|
|
}
|
|
|
|
// Table de correspondance, pour chaque (var, check), index dans la liste de voisins
|
|
// Précalculé une fois à construction du décodeur
|
|
struct EdgeIndex {
|
|
var_pos_in_chk: Vec<Vec<usize>>,
|
|
chk_pos_in_var: Vec<Vec<usize>>,
|
|
}
|
|
|
|
impl EdgeIndex {
|
|
fn build(graph: &TannerGraph) -> Self {
|
|
let var_pos_in_chk = (0..graph.n_chk)
|
|
.map(|c| {
|
|
graph
|
|
.chk_neighbors(c)
|
|
.iter()
|
|
.map(|&v| graph.var_neighbors(v).iter().position(|&x| x == c).unwrap())
|
|
.collect()
|
|
})
|
|
.collect();
|
|
let chk_pos_in_var = (0..graph.n_var)
|
|
.map(|v| {
|
|
graph
|
|
.var_neighbors(v)
|
|
.iter()
|
|
.map(|&c| graph.chk_neighbors(c).iter().position(|&x| x == v).unwrap())
|
|
.collect()
|
|
})
|
|
.collect();
|
|
Self {
|
|
var_pos_in_chk,
|
|
chk_pos_in_var,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Bit-Flipping
|
|
pub struct BitFlippingDecoder {
|
|
graph: TannerGraph,
|
|
h: SparseMatrixGF2,
|
|
config: DecoderConfig,
|
|
}
|
|
|
|
impl BitFlippingDecoder {
|
|
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
|
Self {
|
|
graph: code.graph.clone(),
|
|
h: code.h.clone(),
|
|
config,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Decoder for BitFlippingDecoder {
|
|
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
|
let mut bits: Vec<Gf2> = channel_llr.iter().map(|&l| hard_decision(l)).collect();
|
|
|
|
for _iter in 0..self.config.max_iterations {
|
|
let syndrome = compute_syndrome(&self.h, &bits);
|
|
if self.config.early_stopping && syndrome.iter().all(|&s| s == 0) {
|
|
return DecoderResult::Converged(bits);
|
|
}
|
|
let mut unsatisfied = vec![0usize; self.graph.n_var];
|
|
for c in 0..self.graph.n_chk {
|
|
if syndrome[c] == 1 {
|
|
for &v in self.graph.chk_neighbors(c) {
|
|
unsatisfied[v] += 1;
|
|
}
|
|
}
|
|
}
|
|
let mut flipped = false;
|
|
for v in 0..self.graph.n_var {
|
|
if unsatisfied[v] > self.graph.var_degree(v) / 2 {
|
|
bits[v] ^= 1;
|
|
flipped = true;
|
|
}
|
|
}
|
|
if !flipped {
|
|
break;
|
|
}
|
|
}
|
|
|
|
let synd = compute_syndrome(&self.h, &bits);
|
|
if synd.iter().all(|&s| s == 0) {
|
|
DecoderResult::Converged(bits)
|
|
} else {
|
|
DecoderResult::MaxIterationsReached(bits)
|
|
}
|
|
}
|
|
}
|
|
|
|
// BP
|
|
fn bp_decode<F>(
|
|
graph: &TannerGraph,
|
|
h: &SparseMatrixGF2,
|
|
config: &DecoderConfig,
|
|
channel_llr: &[Llr],
|
|
edge_idx: &EdgeIndex,
|
|
check_update: F,
|
|
) -> DecoderResult
|
|
where
|
|
F: Fn(&[Llr], &mut [Llr]),
|
|
{
|
|
let mut msgs = Messages::new(graph);
|
|
|
|
// Init
|
|
for c in 0..graph.n_chk {
|
|
for (j, &v) in graph.chk_neighbors(c).iter().enumerate() {
|
|
msgs.v2c[c][j] = channel_llr[v];
|
|
}
|
|
}
|
|
|
|
for _iter in 0..config.max_iterations {
|
|
// Maj des checknodes
|
|
for c in 0..graph.n_chk {
|
|
let v2c = msgs.v2c[c].clone();
|
|
check_update(&v2c, &mut msgs.c2v[c]);
|
|
}
|
|
|
|
// Maj des varnodes
|
|
for v in 0..graph.n_var {
|
|
let neighbors = graph.var_neighbors(v);
|
|
// Rassembler les c2v entrants sur ce varnode
|
|
let incoming: Vec<Llr> = neighbors
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, &c)| {
|
|
let j = edge_idx.chk_pos_in_var[v][i];
|
|
msgs.c2v[c][j]
|
|
})
|
|
.collect();
|
|
let mut new_v2c = vec![0.0; neighbors.len()];
|
|
variable_node_update(channel_llr[v], &incoming, &mut new_v2c);
|
|
for (i, &c) in neighbors.iter().enumerate() {
|
|
let j = edge_idx.chk_pos_in_var[v][i];
|
|
msgs.v2c[c][j] = new_v2c[i];
|
|
}
|
|
}
|
|
|
|
// Hard décision + arrêt
|
|
if config.early_stopping {
|
|
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
|
if compute_syndrome(h, &bits).iter().all(|&s| s == 0) {
|
|
return DecoderResult::Converged(bits);
|
|
}
|
|
}
|
|
}
|
|
|
|
let bits = make_decision(graph, &msgs, channel_llr, edge_idx);
|
|
let synd = compute_syndrome(h, &bits);
|
|
if synd.iter().all(|&s| s == 0) {
|
|
DecoderResult::Converged(bits)
|
|
} else {
|
|
DecoderResult::MaxIterationsReached(bits)
|
|
}
|
|
}
|
|
|
|
fn make_decision(
|
|
graph: &TannerGraph,
|
|
msgs: &Messages,
|
|
channel_llr: &[Llr],
|
|
edge_idx: &EdgeIndex,
|
|
) -> Vec<Gf2> {
|
|
(0..graph.n_var)
|
|
.map(|v| {
|
|
let incoming: Vec<Llr> = graph
|
|
.var_neighbors(v)
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, &c)| {
|
|
let j = edge_idx.chk_pos_in_var[v][i];
|
|
msgs.c2v[c][j]
|
|
})
|
|
.collect();
|
|
hard_decision(posterior_llr(channel_llr[v], &incoming))
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
// Sum-Product
|
|
pub struct SumProductDecoder {
|
|
graph: TannerGraph,
|
|
h: SparseMatrixGF2,
|
|
config: DecoderConfig,
|
|
edge_idx: EdgeIndex,
|
|
}
|
|
|
|
impl SumProductDecoder {
|
|
pub fn new(code: &LdpcCode, config: DecoderConfig) -> Self {
|
|
let edge_idx = EdgeIndex::build(&code.graph);
|
|
Self {
|
|
graph: code.graph.clone(),
|
|
h: code.h.clone(),
|
|
config,
|
|
edge_idx,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Decoder for SumProductDecoder {
|
|
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
|
bp_decode(
|
|
&self.graph,
|
|
&self.h,
|
|
&self.config,
|
|
channel_llr,
|
|
&self.edge_idx,
|
|
|incoming, out| check_node_update_sp(incoming, out),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Min-Sum
|
|
pub struct MinSumDecoder {
|
|
graph: TannerGraph,
|
|
h: SparseMatrixGF2,
|
|
config: DecoderConfig,
|
|
scaling_factor: Llr,
|
|
edge_idx: EdgeIndex,
|
|
}
|
|
|
|
impl MinSumDecoder {
|
|
pub fn new(code: &LdpcCode, config: DecoderConfig, scaling_factor: Llr) -> Self {
|
|
let edge_idx = EdgeIndex::build(&code.graph);
|
|
Self {
|
|
graph: code.graph.clone(),
|
|
h: code.h.clone(),
|
|
config,
|
|
scaling_factor,
|
|
edge_idx,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Decoder for MinSumDecoder {
|
|
fn decode(&self, channel_llr: &[Llr]) -> DecoderResult {
|
|
let alpha = self.scaling_factor;
|
|
bp_decode(
|
|
&self.graph,
|
|
&self.h,
|
|
&self.config,
|
|
channel_llr,
|
|
&self.edge_idx,
|
|
move |incoming, out| check_node_update_ms(incoming, out, alpha),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Factory
|
|
#[derive(Debug, Clone)]
|
|
pub enum DecoderMethod {
|
|
BitFlipping,
|
|
SumProduct,
|
|
MinSum { scaling_factor: Llr },
|
|
}
|
|
|
|
pub fn build_decoder(
|
|
code: &LdpcCode,
|
|
method: DecoderMethod,
|
|
config: DecoderConfig,
|
|
) -> Box<dyn Decoder> {
|
|
match method {
|
|
DecoderMethod::BitFlipping => Box::new(BitFlippingDecoder::new(code, config)),
|
|
DecoderMethod::SumProduct => Box::new(SumProductDecoder::new(code, config)),
|
|
DecoderMethod::MinSum { scaling_factor } => {
|
|
Box::new(MinSumDecoder::new(code, config, scaling_factor))
|
|
}
|
|
}
|
|
}
|