18 #ifndef __ARITHMTMASKING_H_
19 #define __ARITHMTMASKING_H_
30 m_nElements = numelements;
32 m_nMTBitLen =
sizeof(T) * 8;
33 memset(&m_nBitMask, 0xFF,
sizeof(T));
34 m_nOTByteLen =
sizeof(T) * m_nElements;
35 aesexpand = m_nOTByteLen > AES_BYTES;
38 m_bBuf = (BYTE*) malloc(
sizeof(BYTE) * AES_BYTES);
39 m_bCtrBuf = (BYTE*) malloc(
sizeof(BYTE) * AES_BYTES);
40 rndbuf.
CreateBytes(PadToMultiple(m_nOTByteLen, AES_BYTES));
54 void Mask(uint32_t progress, uint32_t processedOTs,
CBitVector* values,
CBitVector* snd_buf, BYTE version) {
57 if (progress % m_nMTBitLen != 0 || processedOTs % m_nMTBitLen != 0) {
58 cerr <<
"progress or processed OTs not divisible by MTBitLen, cannot guarantee correct result. Progress = " << progress <<
", processed OTs " << processedOTs
59 <<
", MTBitLen = " << m_nMTBitLen << endl;
62 T tmpval, diff, gtmpval[m_nElements];
64 for (uint32_t i = 0; i < m_nElements; i++)
67 #ifdef DEBUGARITHMTMASKING
68 cout <<
"Starting" << endl;
69 cout <<
"m_vInput.size= " << m_vInput->
GetSize() <<
" progress = " << progress <<
", mtbitlen = " << m_nMTBitLen << endl;
73 uint32_t startpos = (progress / (m_nMTBitLen * m_nElements));
75 T* input = (T*) m_vInput->
GetArr();
76 T* rndval = (T*) snd_buf[0].GetArr();
77 T* maskedval = (T*) snd_buf[1].GetArr();
79 T* retvals = ((T*) values[0].GetArr()) + startpos * m_nElements;
81 for (uint32_t mtid = startpos, i = 0, mtbit, j, ctr = 0; i < processedOTs; mtid++) {
83 #ifdef DEBUGARITHMTMASKING
84 cout <<
"mtid = " << mtid <<
"; getting from " << mtbit * m_nMTBitLen <<
" to " << m_nMTBitLen <<
", val = " << (UINT64_T) diff << endl;
87 for (mtbit = 0; mtbit < m_nMTBitLen; mtbit++, i++) {
88 for (j = 0; j < m_nElements; j++, ctr++) {
91 #ifdef DEBUGARITHMTMASKING
92 cout <<
"S: i = " << i <<
", diff " << (UINT64_T) diff <<
" tmpval = " << (UINT64_T)tmpval;
95 gtmpval[j] = gtmpval[j] + tmpval;
96 tmpval = diff - tmpval;
97 #ifdef DEBUGARITHMTMASKING
98 cout <<
", added = " << (UINT64_T) tmpval <<
", masked = " << (UINT64_T) snd_buf[1].
Get<T>(i * m_nMTBitLen, m_nMTBitLen) <<
", tmpsum mask = " << (UINT64_T) gtmpval[j] << endl;
101 maskedval[ctr] ^= tmpval;
107 for (j = 0; j < m_nElements; j++, retvals++) {
108 #ifdef DEBUGARITHMTMASKING
109 cout <<
"Computed Mask = " << (UINT64_T) gtmpval[j] << endl;
111 retvals[0] = gtmpval[j];
122 if (progress % m_nMTBitLen != 0 || processedOTs % m_nMTBitLen != 0) {
123 cerr <<
"progress or processed OTs not divisible by MTBitLen, cannot guarantee correct result. Progress = " << progress <<
", processed OTs " << processedOTs
124 <<
", MTBitLen = " << m_nMTBitLen << endl;
127 T tmpval, gtmpval[m_nElements];
128 uint32_t lim = progress + processedOTs;
129 BYTE* rcvbufptr = rcv_buf.
GetArr();
131 for (uint32_t i = 0; i < m_nElements; i++)
134 uint32_t startpos = progress / (m_nMTBitLen * m_nElements);
136 T* masks = (T*) tmpmasks.
GetArr();
137 T* rcvedvals = (T*) rcv_buf.
GetArr();
138 T* outvals = ((T*) output.
GetArr()) + startpos * m_nElements;
140 for (uint32_t mtid = startpos, i = progress, mtbit, j; i < lim; mtid++) {
141 #ifdef DEBUGARITHMTMASKING
142 cout <<
"Receiver val = " << (UINT64_T) tmpmasks.
Get<T>(mtid * m_nMTBitLen, m_nMTBitLen) <<
", bits = ";
143 tmpmasks.
Print(mtid * m_nMTBitLen, (mtid + 1) * m_nMTBitLen);
145 for (mtbit = 0; mtbit < m_nMTBitLen; mtbit++, i++, rcvbufptr += m_nOTByteLen) {
147 tmpmasks.
XORBytes(rcvbufptr, i * m_nOTByteLen, m_nOTByteLen);
148 for (j = 0; j < m_nElements; j++) {
149 tmpval = masks[i * m_nElements + j];
150 gtmpval[j] = gtmpval[j] + tmpval;
151 #ifdef DEBUGARITHMTMASKING
152 cout <<
"R: i = " << i <<
", tmpval " << (UINT64_T) tmpval <<
", tmpsum = " << (UINT64_T) gtmpval[j] <<
", choice = " << (UINT64_T) choices.
GetBitNoMask(i) << endl;
156 for (j = 0; j < m_nElements; j++) {
157 tmpval = masks[i * m_nElements + j];
160 #ifdef DEBUGARITHMTMASKING
161 cout <<
"R: i = " << i <<
", tmpval " << (UINT64_T) tmpval <<
", tmpsum = " << (UINT64_T) gtmpval[j] <<
", choice = " << (UINT64_T) choices.
GetBitNoMask(i) << endl;
168 for (j = 0; j < m_nElements; j++, outvals++) {
169 #ifdef DEBUGARITHMTMASKING
170 cout <<
"Computed = " << (UINT64_T) gtmpval[j] << endl;
172 outvals[0] = gtmpval[j];
179 void expandMask(
CBitVector& out, BYTE* sbp, uint32_t offset, uint32_t processedOTs, uint32_t bitlength,
crypto* crypt) {
184 BYTE* outptr = out.
GetArr() + offset * m_nOTByteLen;
185 for (uint32_t i = 0; i < processedOTs; i++, sbp += AES_KEY_BYTES, outptr += m_nOTByteLen) {
186 memcpy(outptr, sbp, m_nOTByteLen);
189 memset(m_bCtrBuf, 0, AES_BYTES);
190 uint32_t* counter = (uint32_t*) m_bCtrBuf;
191 for (uint32_t i = 0, rem; i < processedOTs; i++, sbp += AES_KEY_BYTES) {
193 crypt->init_aes_key(&tkey, sbp);
194 for (counter[0] = 0; counter[0] < ceil_divide(m_nOTByteLen, AES_BYTES); counter[0]++) {
195 crypt->encrypt(&tkey, m_bBuf, m_bCtrBuf, AES_BYTES);
196 rndbuf.
SetBytes(m_bBuf, counter[0] * AES_BYTES, AES_BYTES);
199 out.
SetBytes(rndbuf.
GetArr(), (offset + i) * m_nOTByteLen, m_nOTByteLen);
206 uint32_t m_nElements;
207 uint32_t m_nOTByteLen;
208 uint32_t m_nMTBitLen;
void PrintBinary()
Definition: cbitvector.h:837
BYTE GetBitNoMask(int idx)
Definition: cbitvector.h:467
void XORBytes(BYTE *p, int pos, int len)
Definition: cbitvector.cpp:269
int GetSize()
Definition: cbitvector.h:322
void delCBitVector()
Definition: cbitvector.h:172
T Get(int pos, int len)
Definition: cbitvector.h:577
void SetBytes(BYTE *p, int pos, int len)
Definition: cbitvector.cpp:302
void CreateBytes(uint64_t bytes)
Definition: cbitvector.h:216
void Print(int fromBit, int toBit)
Definition: cbitvector.cpp:347
BYTE * GetArr()
Definition: cbitvector.h:777
Definition: arithmtmasking.h:27
Definition: maskingfunction.h:25
Masking Function implementation.
Definition: cbitvector.h:123