ABY Framework  1.0
Arithmetic Bool Yao Framework
 All Classes Files Functions Variables Enumerations Enumerator Macros
arithmtmasking.h
Go to the documentation of this file.
1 
18 #ifndef __ARITHMTMASKING_H_
19 #define __ARITHMTMASKING_H_
20 
#include <vector>

#include "maskingfunction.h"
22 
23 //#define DEBUGARITHMTMASKING
24 //TODO optimize
25 
26 template<typename T>
28 public:
29  ArithMTMasking(uint32_t numelements, CBitVector* in) {
30  m_nElements = numelements; //=K
31  m_vInput = in; //contains x and u, is 2-dim in case of the server and 1-dim in case of the client
32  m_nMTBitLen = sizeof(T) * 8;
33  memset(&m_nBitMask, 0xFF, sizeof(T));
34  m_nOTByteLen = sizeof(T) * m_nElements;
35  aesexpand = m_nOTByteLen > AES_BYTES;
36 
37  if (aesexpand) {
38  m_bBuf = (BYTE*) malloc(sizeof(BYTE) * AES_BYTES);
39  m_bCtrBuf = (BYTE*) malloc(sizeof(BYTE) * AES_BYTES);
40  rndbuf.CreateBytes(PadToMultiple(m_nOTByteLen, AES_BYTES));
41  }
42 
43  }
44  ;
45 
46  ~ArithMTMasking() {
47  free(m_bBuf);
48  free(m_bCtrBuf);
49  rndbuf.delCBitVector();
50  }
51  ;
52 
53  //In total K' OTs will be performed
54  void Mask(uint32_t progress, uint32_t processedOTs, CBitVector* values, CBitVector* snd_buf, BYTE version) {
55 
56  //progress and processedOTs should always be divisible by MTBitLen
57  if (progress % m_nMTBitLen != 0 || processedOTs % m_nMTBitLen != 0) {
58  cerr << "progress or processed OTs not divisible by MTBitLen, cannot guarantee correct result. Progress = " << progress << ", processed OTs " << processedOTs
59  << ", MTBitLen = " << m_nMTBitLen << endl;
60  }
61 
62  T tmpval, diff, gtmpval[m_nElements];
63 
64  for (uint32_t i = 0; i < m_nElements; i++)
65  gtmpval[i] = 0;
66 
67 #ifdef DEBUGARITHMTMASKING
68  cout << "Starting" << endl;
69  cout << "m_vInput.size= " << m_vInput->GetSize() << " progress = " << progress << ", mtbitlen = " << m_nMTBitLen << endl;
70  m_vInput->PrintBinary();
71 #endif
72 
73  uint32_t startpos = (progress / (m_nMTBitLen * m_nElements));
74 
75  T* input = (T*) m_vInput->GetArr();
76  T* rndval = (T*) snd_buf[0].GetArr();
77  T* maskedval = (T*) snd_buf[1].GetArr();
78 
79  T* retvals = ((T*) values[0].GetArr()) + startpos * m_nElements;
80 
81  for (uint32_t mtid = startpos, i = 0, mtbit, j, ctr = 0; i < processedOTs; mtid++) {
82  diff = input[mtid]; //m_vInput->Get<T>(mtid * m_nMTBitLen, m_nMTBitLen);
83 #ifdef DEBUGARITHMTMASKING
84  cout << "mtid = " << mtid << "; getting from " << mtbit * m_nMTBitLen << " to " << m_nMTBitLen << ", val = " << (UINT64_T) diff << endl;
85 #endif
86 
87  for (mtbit = 0; mtbit < m_nMTBitLen; mtbit++, i++) {
88  for (j = 0; j < m_nElements; j++, ctr++) {
89  //Get randomly generated mask from snd_buf[0]
90  tmpval = rndval[ctr];
91 #ifdef DEBUGARITHMTMASKING
92  cout << "S: i = " << i << ", diff " << (UINT64_T) diff << " tmpval = " << (UINT64_T)tmpval;
93 #endif
94  //Add random mask to the already generated masks for this MT
95  gtmpval[j] = gtmpval[j] + tmpval;
96  tmpval = diff - tmpval;
97 #ifdef DEBUGARITHMTMASKING
98  cout << ", added = " << (UINT64_T) tmpval << ", masked = " << (UINT64_T) snd_buf[1].Get<T>(i * m_nMTBitLen, m_nMTBitLen) << ", tmpsum mask = " << (UINT64_T) gtmpval[j] << endl;
99 #endif
100  //Mask the resulting correlation with the second OT result
101  maskedval[ctr] ^= tmpval;
102  }
103  diff = diff << 1;
104  }
105 
106  //Write out the result into values[0]
107  for (j = 0; j < m_nElements; j++, retvals++) {
108 #ifdef DEBUGARITHMTMASKING
109  cout << "Computed Mask = " << (UINT64_T) gtmpval[j] << endl;
110 #endif
111  retvals[0] = gtmpval[j];
112  gtmpval[j] = 0;
113  }
114  }
115  }
116  ;
117 
118  //rcv_buf holds the masked values that were sent by the sender, output holds the masks that were generated by the receiver
119  void UnMask(uint32_t progress, uint32_t processedOTs, CBitVector& choices, CBitVector& output, CBitVector& rcv_buf, CBitVector& tmpmasks, BYTE version) {
120 
121  //progress and processedOTs should always be divisible by MTBitLen
122  if (progress % m_nMTBitLen != 0 || processedOTs % m_nMTBitLen != 0) {
123  cerr << "progress or processed OTs not divisible by MTBitLen, cannot guarantee correct result. Progress = " << progress << ", processed OTs " << processedOTs
124  << ", MTBitLen = " << m_nMTBitLen << endl;
125  }
126 
127  T tmpval, gtmpval[m_nElements];
128  uint32_t lim = progress + processedOTs;
129  BYTE* rcvbufptr = rcv_buf.GetArr();
130 
131  for (uint32_t i = 0; i < m_nElements; i++)
132  gtmpval[i] = 0;
133 
134  uint32_t startpos = progress / (m_nMTBitLen * m_nElements);
135 
136  T* masks = (T*) tmpmasks.GetArr();
137  T* rcvedvals = (T*) rcv_buf.GetArr();
138  T* outvals = ((T*) output.GetArr()) + startpos * m_nElements;
139 
140  for (uint32_t mtid = startpos, i = progress, mtbit, j; i < lim; mtid++) {
141 #ifdef DEBUGARITHMTMASKING
142  cout << "Receiver val = " << (UINT64_T) tmpmasks.Get<T>(mtid * m_nMTBitLen, m_nMTBitLen) << ", bits = ";
143  tmpmasks.Print(mtid * m_nMTBitLen, (mtid + 1) * m_nMTBitLen);
144 #endif
145  for (mtbit = 0; mtbit < m_nMTBitLen; mtbit++, i++, rcvbufptr += m_nOTByteLen) {
146  if (choices.GetBitNoMask(i)) {
147  tmpmasks.XORBytes(rcvbufptr, i * m_nOTByteLen, m_nOTByteLen);
148  for (j = 0; j < m_nElements; j++) {
149  tmpval = masks[i * m_nElements + j];
150  gtmpval[j] = gtmpval[j] + tmpval;
151 #ifdef DEBUGARITHMTMASKING
152  cout << "R: i = " << i << ", tmpval " << (UINT64_T) tmpval << ", tmpsum = " << (UINT64_T) gtmpval[j] << ", choice = " << (UINT64_T) choices.GetBitNoMask(i) << endl;
153 #endif
154  }
155  } else {
156  for (j = 0; j < m_nElements; j++) {
157  tmpval = masks[i * m_nElements + j];
158  gtmpval[j] =
159  gtmpval[j] - tmpval;
160 #ifdef DEBUGARITHMTMASKING
161  cout << "R: i = " << i << ", tmpval " << (UINT64_T) tmpval << ", tmpsum = " << (UINT64_T) gtmpval[j] << ", choice = " << (UINT64_T) choices.GetBitNoMask(i) << endl;
162 #endif
163  }
164  }
165  }
166 
167  //Write out the result into values[0]
168  for (j = 0; j < m_nElements; j++, outvals++) {
169 #ifdef DEBUGARITHMTMASKING
170  cout << "Computed = " << (UINT64_T) gtmpval[j] << endl;
171 #endif
172  outvals[0] = gtmpval[j];
173  gtmpval[j] = 0;
174  }
175  }
176  }
177  ;
178 
179  void expandMask(CBitVector& out, BYTE* sbp, uint32_t offset, uint32_t processedOTs, uint32_t bitlength, crypto* crypt) {
180 
181  //the CBitVector to store the random values in
182 
183  if (!aesexpand) {
184  BYTE* outptr = out.GetArr() + offset * m_nOTByteLen;
185  for (uint32_t i = 0; i < processedOTs; i++, sbp += AES_KEY_BYTES, outptr += m_nOTByteLen) {
186  memcpy(outptr, sbp, m_nOTByteLen);
187  }
188  } else {
189  memset(m_bCtrBuf, 0, AES_BYTES);
190  uint32_t* counter = (uint32_t*) m_bCtrBuf;
191  for (uint32_t i = 0, rem; i < processedOTs; i++, sbp += AES_KEY_BYTES) {
192  //Generate sufficient random bits
193  crypt->init_aes_key(&tkey, sbp);
194  for (counter[0] = 0; counter[0] < ceil_divide(m_nOTByteLen, AES_BYTES); counter[0]++) {
195  crypt->encrypt(&tkey, m_bBuf, m_bCtrBuf, AES_BYTES);
196  rndbuf.SetBytes(m_bBuf, counter[0] * AES_BYTES, AES_BYTES);
197  }
198  //Copy random bits into output vector
199  out.SetBytes(rndbuf.GetArr(), (offset + i) * m_nOTByteLen, m_nOTByteLen);
200  }
201  }
202  }
203 
204 private:
205  CBitVector* m_vInput; // input vector handed to the constructor; Mask reads it as an array of T
206  uint32_t m_nElements; // number of T-sized elements per OT (=K)
207  uint32_t m_nOTByteLen; // bytes per OT message: sizeof(T) * m_nElements
208  uint32_t m_nMTBitLen; // bit length of T: sizeof(T) * 8
209  uint64_t m_nBitMask; // low sizeof(T) bytes are set to 0xFF by the constructor
210  BYTE* m_bBuf; // AES_BYTES scratch buffer for cipher output; allocated only when aesexpand is set
211  BYTE* m_bCtrBuf; // AES_BYTES counter block fed to the cipher; allocated only when aesexpand is set
212  AES_KEY_CTX tkey; // AES key context, re-initialised for every OT in expandMask
213  BOOL aesexpand; // true when m_nOTByteLen > AES_BYTES, i.e. expandMask has to stretch the seed
214  CBitVector rndbuf; // holds the expanded bytes of one OT before they are copied to the output
215 };
216 
217 #endif /* __ARITHMTMASKING_H_ */
void PrintBinary()
Definition: cbitvector.h:837
BYTE GetBitNoMask(int idx)
Definition: cbitvector.h:467
void XORBytes(BYTE *p, int pos, int len)
Definition: cbitvector.cpp:269
Definition: crypto.h:58
int GetSize()
Definition: cbitvector.h:322
void delCBitVector()
Definition: cbitvector.h:172
T Get(int pos, int len)
Definition: cbitvector.h:577
void SetBytes(BYTE *p, int pos, int len)
Definition: cbitvector.cpp:302
void CreateBytes(uint64_t bytes)
Definition: cbitvector.h:216
void Print(int fromBit, int toBit)
Definition: cbitvector.cpp:347
BYTE * GetArr()
Definition: cbitvector.h:777
Definition: arithmtmasking.h:27
Definition: maskingfunction.h:25
Masking Function implementation.
Definition: cbitvector.h:123