Skip to content

Commit

Permalink
added better stern fq
Browse files Browse the repository at this point in the history
  • Loading branch information
Floyd committed Mar 12, 2024
1 parent 210aae9 commit d08e93e
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 148 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)

# options
option(USE_TCMALLOC "Enable tcmalloc" OFF)
option(USE_SANITIZER "Enable memory sanitizer" OFF)
option(USE_SANITIZER "Enable memory sanitizer" ON)
option(USE_NOINLINE "Enable -fno-inline" OFF)
option(USE_LTO "Enable -flto" OFF)
option(USE_PROFILE "Enable profile guided optmization" OFF)
Expand Down Expand Up @@ -99,11 +99,11 @@ if(USE_TCMALLOC)
set(MALLOC_LIBRARY_FLAGS "tcmalloc")
endif()

if(USE_SANITIZER AND COMPILER_SUPPORTS_SANITIZE_ADDRESS)
#if(USE_SANITIZER AND COMPILER_SUPPORTS_SANITIZE_ADDRESS)
message(STATUS "Using address sanitize")
# possible sanitizers = -fsanitize=[address,leak,thread,memory,undefined]
set(SANITIZER_FLAGS "-fsanitize=address -fsanitize=pointer-compare -fno-omit-frame-pointer")
endif()
#endif()

if(USE_NOINLINE AND COMPILER_SUPPORTS_NO_INLINE)
message(STATUS "not inlining")
Expand Down
33 changes: 14 additions & 19 deletions src/fq/stern.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class FqStern : public ISDInstance<uint64_t, isd> {
using IndexType = TypeTemplate<HM_nrb * HM_bs>;

/// needed list stuff
constexpr static uint32_t enum_length = (k + l + 1u) >> 1u;
constexpr static uint32_t enum_length = (k + l) >> 1u;
constexpr static uint32_t enum_offset = k + l - enum_length;
constexpr static size_t list_size = compute_combinations_fq_chase_list_size<enum_length, q, p>();
using List = Parallel_List_FullElement_T<Element>;
Expand Down Expand Up @@ -115,8 +115,7 @@ class FqStern : public ISDInstance<uint64_t, isd> {
/// init lists
L1 = new List(list_size, threads);
L2 = new List(list_size, threads);
ASSERT(L1 != nullptr);
ASSERT(L2 != nullptr);
ASSERT((L1 != nullptr) && (L2 != nullptr));

///// init hashmap
hm = new HM;
Expand Down Expand Up @@ -144,33 +143,30 @@ class FqStern : public ISDInstance<uint64_t, isd> {
const size_t end = L2->end_pos(tid);

Label tmp;
l_type synd = Compress(ws);
for (size_t i = start; i < end; ++i) {
l_type data = NegateCompress(L2->at(i).label);
data = Label::template add_T<l_type>(data, synd);
data = Label::template add_T<l_type>(data, syndrome);

/// search in HM
HM_LoadType load; // TODO maybe hier sowas wie fast_load_type
HM_LoadType load;
IndexType pos = hm->find(data, load);

for (uint64_t j = pos; j < pos + load; j++) {
const IndexType index = hm->ptr(j)[0];
pos += 1;

/// TODO not really nice
Label::add(tmp, L1->at(index).label, L2->at(i).label);
Label::sub(tmp, ws, tmp);

//std::cout << std::endl;
//L2->at(i).label.print();
//L1->at(index).label.print();
//tmp.print();
//ws.print();
//std::cout << i << " " << index << std::endl;
// std::cout << std::endl;
// L2->at(i).label.print();
// L1->at(index).label.print();
// tmp.print();
// ws.print();
// std::cout << i << " " << index << std::endl;

/// some debug checks
for (uint32_t j = 0; j < l; ++j) {
ASSERT(tmp.get(j) == 0);
for (uint32_t s = 0; s < l; ++s) {
ASSERT(tmp.get(s) == 0);
}

/// if this checks passes we found a solution
Expand Down Expand Up @@ -204,7 +200,6 @@ class FqStern : public ISDInstance<uint64_t, isd> {
tmpe.set(label_solution_to_recover.get(l + i), 0, (n-k-l) - i - 1);
}

tmpe.print();
for (uint32_t i = 0; i < k+l; ++i) {
tmpe.set(value_solution_to_recover.get(i), 0, (n-k-l)+i);
}
Expand All @@ -227,9 +222,9 @@ class FqStern : public ISDInstance<uint64_t, isd> {
while (not_found && (loops < config.loops)) {
ISD::step();

#pragma omp parallel default(none) shared(std::cout,L1,not_found,loops) num_threads(threads)
// #pragma omp parallel default(none) shared(std::cout,L1,not_found,loops) num_threads(threads)
{
const uint32_t tid = Thread::get_tid();
const uint32_t tid = 0; //Thread::get_tid();
init_list(tid);
Thread::sync();
find_collisions(tid);
Expand Down
216 changes: 118 additions & 98 deletions src/fq/stern_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,22 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {
constexpr static uint32_t p = isd.p;

constexpr static uint32_t threads = config.threads;
constexpr static size_t final_list_size = 1024;

public:
using ISD = ISDInstance<uint64_t, isd>;
using Error = ISD::Error;
using Label = ISD::Label;
using Value = ISD::Value;
using Element = ISD::Element;
using limb_type = ISD::limb_type;
using PCSubMatrix_T = ISD::PCSubMatrix_T;
using ISD::A,ISD::H,ISD::wA,ISD::wAT,ISD::HT,ISD::s,ISD::ws,ISD::e,ISD::syndrome,ISD::P,ISD::not_found,ISD::loops,ISD::ghz;
using ISD::cycles, ISD::gaus_cycles, ISD::periodic_print, ISD::HashMasp_BucketSize, ISD::cL;
using ISD::cycles, ISD::gaus_cycles, ISD::periodic_print, ISD::HashMasp_BucketSize;

constexpr static uint32_t HM_nrb = ISD::HashMasp_BucketSize(l);
constexpr static uint32_t HM_bs = config.HM_bucketsize;
constexpr static uint32_t NR_HT_LIMBS = PCSubMatrix_T::limbs_per_row();

constexpr static bool packed = ISD::packed;

Expand All @@ -50,51 +54,9 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {
l_type *lHT;

/// needed list stuff
constexpr static uint32_t enum_length = (k + l + 1u + isd.epsilon) >> 1u;
constexpr static uint32_t enum_length = (k + l + isd.epsilon) >> 1u;
constexpr static uint32_t enum_offset = k + l - enum_length;
// TODO maybe move to ISD
constexpr static size_t enum_size = compute_combinations_fq_chase_list_size<n,q,w>();

/// compress in the case of non packed data containers
/// e.g. removes all the zeros from the containers
constexpr inline static __uint128_t Compress(const Label &label) noexcept {
/// some security measurements:
if constexpr (qbits*l > 128) {
ASSERT(false && "not implemented");
}

/// easy case
if constexpr (packed) {
constexpr __uint128_t mask = qbits*l == 128 ? __uint128_t(-1ull) : (__uint128_t(1ull) << (qbits*l)) - __uint128_t(1ull);
using TT = LogTypeTemplate<qbits*l>;

/// NOTE: that we fetch the first 128bits (and not the last, where we would assume
/// normally the l bit window), as we assume we swapped all rows within the parity
/// check matrix. Therefor the last `l` bit of every syndrome are now the first.
const TT a = *((TT *)label.ptr());
const __uint128_t aa = a;
return aa&mask;
}

// important to init with zero here.
__uint128_t ret = 0;

#pragma unroll
for (uint32_t i = 0u; i < l; ++i) {
ret ^= (label.get(i) << (qbits*i));
}

return ret;
}

/// this extractor is needed, to be able to search on the second
/// list for elements, which are equal zero summed together
/// \return the negative of label[lower, q*l*nr_window) shifted to zero
constexpr inline static __uint128_t NegateCompress(const Label &label) noexcept {
Label tmp = label;
tmp.neg();
return Compress(tmp);
}
constexpr static size_t enum_size = compute_combinations_fq_chase_list_size<enum_length, 2, p>();

using V1 = CollisionType<l_type, uint16_t, 1>;
constexpr static SimpleHashMapConfig simpleHashMapConfig{
Expand All @@ -106,26 +68,99 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {
using HM_DataType_IndexType = V1::index_type;
using HM_DataType = V1;


/// changelist stuff
chase<(config.k+config.l)/2 + config.epsilon, config.p, config.q> c{};
using cle = std::pair<uint16_t, uint16_t>;
std::vector<cle> cL;

///
size_t cfls = 0;
alignas(256) std::array<uint32_t, final_list_size> final_list_left;
alignas(256) std::array<uint32_t, final_list_size> final_list_right;

/// list entries of the current permutation if a solution was found.
size_t solutions[2*p] = {0};

Value value_solution_to_recover;
Label label_solution_to_recover;

/// base constructor
FqSternV2() noexcept : ISDInstance<uint64_t, isd>(true) {
FqSternV2() noexcept {
constexpr size_t size_lHT = roundToAligned<1024>(sizeof(l_type) * (k+l));
lHT =(l_type *)cryptanalysislib::aligned_alloc(1024, size_lHT);
ASSERT(lHT);

hm = new HM{};
ASSERT(hm);
hm->print();


cL.resize(enum_size);
size_t ctr = 0;
c.enumerate([&, this](const uint16_t p1, const uint16_t p2){
cL[ctr] = cle{p1, p2};
ctr += 1;
});
ASSERT(ctr == (enum_size - 1));
}

/// free all the memory
~FqSternV2() {
delete hm;
}

bool compute_finale_list() noexcept {
alignas(32) uint16_t left[p], right[p];
for (size_t m = 0; m < cfls; ++m) {
biject<enum_length, p>(final_list_left[m], left);
biject<enum_length, p>(final_list_right[m], right);

/// TODO the following things are wrong:
/// - biject is not correct, we need to match the +q-1 offset
/// - missing lpart = 0 checks
auto climb = ws.ptr(0);
for (uint16_t j = 0; j < p; j++) {
const auto tmp = Label::add_T(HT.limb(left[j], 0), HT.limb(right[j], 0));
climb = Label::add_T(climb, tmp);
}

uint32_t wt = Label::popcnt_T(climb);
if (likely(wt > (w - (2*p)))) {
continue;
}


for (uint32_t i = 1; i < NR_HT_LIMBS; i++) {
climb = ws.ptr(i);
for (uint16_t j = 0; j < p; j++) {
const auto tmp = Label::add_T(HT.limb(left[j], i), HT.limb(right[j], i));
climb = Label::add_T(climb, tmp);
}

wt += cryptanalysislib::popcount::popcount(climb);
}

if ((wt <= w - (2*p)) && not_found) {
not_found = false;
cycles = cpucycles() - cycles;
for (uint16_t j = 0; j < p; ++j) {
solutions[j*p + 0] = left[j];
solutions[j*p + 1] = right[j];
}

cfls = 0;
return true;
}
}

cfls = 0;
return false;
}

///
inline void init_list(const uint32_t tid) {
// hm->clear();
hm->clear();

l_type tmp = 0;
for (uint32_t i = 0; i < p; ++i) {
Expand All @@ -134,7 +169,7 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {

for (size_t i = 0; i < enum_size; ++i) {
const uint16_t ci = cL[i].second;
for (uint32_t j = 0; j < q; ++j) {
for (uint32_t j = 0; j < (q - 1); ++j) {
hm->insert(tmp, HM_DataType::create(tmp, (HM_DataType_IndexType *)&i));
tmp = Label::template add_T<l_type>(tmp, lHT[ci]);
}
Expand All @@ -144,54 +179,38 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {
}
}

// void find_collisions(const uint32_t tid) {
// const size_t start = L2->start_pos(tid);
// const size_t end = L2->end_pos(tid);
//
// Label tmp;
// l_type synd = Compress(ws);
// for (size_t i = start; i < end; ++i) {
// l_type data = NegateCompress(L2->at(i).label);
// data = Label::template add_T<l_type>(data, synd);
//
// /// search in HM
// HM_LoadType load; // TODO maybe hier sowas wie fast_load_type
// IndexType pos = hm->find(data, load);
//
// for (uint64_t j = pos; j < pos + load; j++) {
// const IndexType index = hm->ptr(j)[0];
// pos += 1;
//
// /// TODO not really nice
// Label::add(tmp, L1->at(index).label, L2->at(i).label);
// Label::sub(tmp, ws, tmp);
//
// //std::cout << std::endl;
// //L2->at(i).label.print();
// //L1->at(index).label.print();
// //tmp.print();
// //ws.print();
// //std::cout << i << " " << index << std::endl;
//
// /// some debug checks
// for (uint32_t j = 0; j < l; ++j) {
// ASSERT(tmp.get(j) == 0);
// }
//
// /// if this checks passes we found a solution
// if (unlikely(tmp.popcnt() <= (w - 2*p))) {
// not_found = false;
// label_solution_to_recover = tmp;
// Value::add(value_solution_to_recover, L2->at(i).value, L1->at(index).value);
// goto finished;
// }
// }
// }
//
// /// forward jumps = best jumps
// finished:
// return;
// }
void find_collisions(const uint32_t tid) {
l_type tmp = syndrome;
for (uint32_t i = 0; i < p; ++i) {
tmp = Label::template add_T<l_type>(tmp, lHT[i]);
}

for (size_t i = 0; i < enum_size; ++i) {
const uint16_t ci = cL[i].second;
for (uint32_t m = 0; m < (q - 1); ++m) {
HM_LoadType load;
IndexType pos = hm->find(tmp, load);
for (uint64_t j = pos; j < pos + load; j++) {
const IndexType index = hm->ptr(j).index[0];

final_list_left[cfls] = index;
final_list_right[cfls] = i;
cfls += 1;

if (cfls >= final_list_size) {
compute_finale_list();
}
}

// next element
tmp = Label::template add_T<l_type>(tmp, lHT[ci]);
}

// next element in the chase sequence
tmp = Label::template add_T<l_type>(tmp, lHT[cL[i].first]);
tmp = Label::template add_T<l_type>(tmp, lHT[cL[i].second]);
}
}

///
void rebuild_solution() {
Expand Down Expand Up @@ -235,11 +254,12 @@ class FqSternV2 : public ISDInstance<uint64_t, isd> {

// #pragma omp parallel default(none) shared(std::cout,L1,not_found,loops) num_threads(threads)
{
const uint32_t tid = Thread::get_tid();
const uint32_t tid = 0; //Thread::get_tid();
init_list(tid);

Thread::sync();
//find_collisions(tid);
find_collisions(tid);
compute_finale_list();
}
}

Expand Down
Loading

0 comments on commit d08e93e

Please sign in to comment.