KSeExpr  4.0.4.0
ExprLLVMCodeGeneration.cpp
Go to the documentation of this file.
1 // SPDX-FileCopyrightText: 2011-2019 Disney Enterprises, Inc.
2 // SPDX-License-Identifier: LicenseRef-Apache-2.0
3 // SPDX-FileCopyrightText: 2020 L. E. Segovia <amy@amyspark.me>
4 // SPDX-License-Identifier: GPL-3.0-or-later
5 
6 #include "ExprConfig.h"
7 
8 #if defined(SEEXPR_ENABLE_LLVM)
9 #include <array>
10 
11 #include "ExprFunc.h"
12 #include "ExprLLVM.h"
13 #include "ExprLLVMAll.h"
14 #include "ExprNode.h"
15 #include "StringUtils.h"
16 #include "VarBlock.h"
17 
18 using namespace llvm;
19 using namespace KSeExpr;
20 
21 // TODO: Use ordered or unordered float comparison?
22 // TODO: factor out commonly used llvm types
23 // TODO: factor out integer/double constant creation
24 namespace
25 {
26 Function *llvm_getFunction(LLVM_BUILDER Builder)
27 {
28  return Builder.GetInsertBlock()->getParent();
29 }
30 
31 Module *llvm_getModule(LLVM_BUILDER Builder)
32 {
33  return llvm_getFunction(Builder)->getParent();
34 }
35 
37 std::string llvmTypeString(llvm::Type *type)
38 {
39  std::string myString;
40  llvm::raw_string_ostream rawStream(myString);
41  type->print(rawStream);
42  return rawStream.str();
43 }
44 
45 bool isVarArg(ExprFuncStandard::FuncType seFuncType)
46 {
47  return seFuncType == ExprFuncStandard::FUNCN || seFuncType == ExprFuncStandard::FUNCNV || seFuncType == ExprFuncStandard::FUNCNVV;
48 }
49 
50 bool isReturnVector(ExprFuncStandard::FuncType seFuncType)
51 {
52  return seFuncType == ExprFuncStandard::FUNC1VV || seFuncType == ExprFuncStandard::FUNC2VV || seFuncType == ExprFuncStandard::FUNCNVV;
53 }
54 
55 bool isTakeOnlyDoubleArg(ExprFuncStandard::FuncType seFuncType)
56 {
57  return seFuncType <= ExprFuncStandard::FUNC6 || seFuncType == ExprFuncStandard::FUNCN;
58 }
59 
60 FunctionType *getSeExprFuncStandardLLVMType(ExprFuncStandard::FuncType sft, LLVMContext &llvmContext)
61 {
62  assert(sft != ExprFuncStandard::NONE);
63 
64  Type *intType = Type::getInt32Ty(llvmContext);
65  Type *doubleType = Type::getDoubleTy(llvmContext);
66  Type *doublePtrType = PointerType::getUnqual(Type::getDoubleTy(llvmContext));
67  Type *voidType = Type::getVoidTy(llvmContext);
68  FunctionType *FT = nullptr;
69 
70  if (sft <= ExprFuncStandard::FUNC6) {
71  std::vector<Type *> paramTypes;
72  switch (sft) {
73  case ExprFuncStandard::FUNC6:
74  paramTypes.push_back(doubleType);
75  case ExprFuncStandard::FUNC5:
76  paramTypes.push_back(doubleType);
77  case ExprFuncStandard::FUNC4:
78  paramTypes.push_back(doubleType);
79  case ExprFuncStandard::FUNC3:
80  paramTypes.push_back(doubleType);
81  case ExprFuncStandard::FUNC2:
82  paramTypes.push_back(doubleType);
83  case ExprFuncStandard::FUNC1:
84  paramTypes.push_back(doubleType);
85  case ExprFuncStandard::FUNC0:
86  default:
87  FT = FunctionType::get(doubleType, paramTypes, false);
88  }
89  } else if (sft == ExprFuncStandard::FUNC1V) {
90  std::array<Type*, 1> paramTypes = {doublePtrType};
91  FT = FunctionType::get(doubleType, paramTypes, false);
92  } else if (sft == ExprFuncStandard::FUNC2V) {
93  std::array<Type *, 2> paramTypes = {doublePtrType, doublePtrType};
94  FT = FunctionType::get(doubleType, paramTypes, false);
95  } else if (sft == ExprFuncStandard::FUNC1VV) {
96  std::array<Type *, 2> paramTypes = {doublePtrType, doublePtrType};
97  FT = FunctionType::get(voidType, paramTypes, false);
98  } else if (sft == ExprFuncStandard::FUNC2VV) {
99  std::array<Type *, 3> paramTypes = {doublePtrType, doublePtrType, doublePtrType};
100  FT = FunctionType::get(voidType, paramTypes, false);
101  } else if (sft == ExprFuncStandard::FUNCN) {
102  std::array<Type *, 2> paramTypes = {intType, doublePtrType};
103  FT = FunctionType::get(doubleType, paramTypes, false);
104  } else if (sft == ExprFuncStandard::FUNCNV) {
105  std::array<Type *, 2> paramTypes = {intType, doublePtrType};
106  FT = FunctionType::get(doubleType, paramTypes, false);
107  } else if (sft == ExprFuncStandard::FUNCNVV) {
108  std::array<Type *, 3> paramTypes = {doublePtrType, intType, doublePtrType};
109  FT = FunctionType::get(voidType, paramTypes, false);
110  } else
111  assert(false);
112 
113  return FT;
114 }
115 
116 LLVM_VALUE CreateCall(LLVM_BUILDER Builder, LLVM_VALUE addrVal, ArrayRef<Value *> args)
117 {
118 #if LLVM_VERSION_MAJOR >= 11
119  // LLVM 11 wants the desired function signature forcibly.
120  // However, Disney covered it in a layer of casts, which I have to undo.
121  auto *funcCast = llvm::cast<llvm::CastInst>(addrVal);
122  assert(funcCast && "ERROR! The callee value is not a pointer cast!");
123  auto *funcPtr = llvm::cast<llvm::PointerType>(funcCast->getDestTy());
124  assert(funcPtr && "ERROR! The callee value does not contain a function!");
125  auto *TY = llvm::cast<llvm::FunctionType>(funcPtr->getElementType());
126  assert(TY && "ERROR! The callee value does not return a function signature!");
127  return Builder.CreateCall(TY, addrVal, args);
128 #else
129  return Builder.CreateCall(addrVal, args);
130 #endif
131 }
132 
133 Type *createLLVMTyForSeExprType(LLVMContext &llvmContext, const ExprType& seType)
134 {
135  if (seType.isFP()) {
136  int dim = seType.dim();
137 #if LLVM_VERSION_MAJOR >= 10
138  return dim == 1 ? Type::getDoubleTy(llvmContext) : VectorType::get(Type::getDoubleTy(llvmContext), dim, false);
139 #else
140  return dim == 1 ? Type::getDoubleTy(llvmContext) : VectorType::get(Type::getDoubleTy(llvmContext), dim);
141 #endif
142  } else if (seType.isString()) {
143  static_assert(sizeof(char*) == 8, "Expect 64-bit pointers");
144  return Type::getInt8PtrTy(llvmContext);
145  }
146  assert(!"unknown SeExpr type encountered"); // unknown type
147  return nullptr;
148 }
149 
150 // Copy a scalar "val" to a vector of "dim" length
151 LLVM_VALUE createVecVal(LLVM_BUILDER Builder, LLVM_VALUE val, unsigned dim)
152 {
153  LLVMContext &llvmContext = Builder.getContext();
154 #if LLVM_VERSION_MAJOR >= 10
155  VectorType *doubleVecTy = VectorType::get(Type::getDoubleTy(llvmContext), dim, false);
156 #else
157  VectorType *doubleVecTy = VectorType::get(Type::getDoubleTy(llvmContext), dim);
158 #endif
159  LLVM_VALUE vecVal = UndefValue::get(doubleVecTy);
160  for (unsigned i = 0; i < dim; i++)
161  vecVal = Builder.CreateInsertElement(vecVal, val, ConstantInt::get(Type::getInt32Ty(llvmContext), i));
162  return vecVal;
163 }
164 
165 // Copy a vector "val" to a vector of the same length
166 LLVM_VALUE createVecVal(LLVM_BUILDER Builder, ArrayRef<LLVM_VALUE> val, const std::string &name = "")
167 {
168  if (val.empty())
169  return nullptr;
170 
171  LLVMContext &llvmContext = Builder.getContext();
172  unsigned dim = val.size();
173 #if LLVM_VERSION_MAJOR >= 10
174  VectorType *elemType = VectorType::get(val[0]->getType(), dim, false);
175 #else
176  VectorType *elemType = VectorType::get(val[0]->getType(), dim);
177 #endif
178  LLVM_VALUE vecVal = UndefValue::get(elemType);
179  for (unsigned i = 0; i < dim; i++)
180  vecVal = Builder.CreateInsertElement(vecVal, val[i], ConstantInt::get(Type::getInt32Ty(llvmContext), i), name);
181  return vecVal;
182 }
183 
184 LLVM_VALUE createVecValFromAlloca(LLVM_BUILDER Builder, AllocaInst *destPtr, unsigned vecLen)
185 {
186  Type *destTy = destPtr->getType()->getPointerElementType();
187  assert(destTy->isDoubleTy() || destTy->isArrayTy());
188  std::vector<LLVM_VALUE> vals;
189 
190  for (unsigned i = 0; i < vecLen; ++i) {
191  LLVM_VALUE ptr = destTy->isDoubleTy() ? CREATE_CONST_GEP1_32(Builder, destPtr, i) : Builder.CreateConstGEP2_32(nullptr, destPtr, 0, i);
192  vals.push_back(CREATE_LOAD(Builder, ptr));
193  }
194 
195  return createVecVal(Builder, vals);
196 }
197 
199 inline unsigned int getVectorNumElements(llvm::Type *ty)
200 {
201 #if LLVM_VERSION_MAJOR >= 11
202  assert(ty && ty->isVectorTy() && "This is not a vector type!");
203  auto *VT = llvm::cast<llvm::VectorType>(ty);
204 #if LLVM_VERSION_MAJOR >= 13
205  return VT->getElementCount().getKnownMinValue();
206 #else
207  return VT->getNumElements();
208 #endif
209 #else
210  return ty->getVectorNumElements();
211 #endif
212 }
213 
214 LLVM_VALUE getFirstElement(LLVM_VALUE V, IRBuilder<> &Builder)
215 {
216  Type *VTy = V->getType();
217  if (VTy->isDoubleTy())
218  return V;
219  if (VTy->isPointerTy())
220  return V;
221 
222  assert(VTy->isVectorTy());
223  LLVMContext &llvmContext = Builder.getContext();
224  LLVM_VALUE zero = ConstantInt::get(Type::getInt32Ty(llvmContext), 0);
225  return Builder.CreateExtractElement(V, zero);
226 }
227 
228 LLVM_VALUE promoteToTy(LLVM_VALUE val, Type *destTy, LLVM_BUILDER Builder)
229 {
230  Type *srcTy = val->getType();
231  if (srcTy == destTy)
232  return val;
233 
234  if (destTy->isDoubleTy())
235  return val;
236 
237  return createVecVal(Builder, val, getVectorNumElements(destTy));
238 }
239 
240 AllocaInst *createAllocaInst(LLVM_BUILDER Builder, Type *ty, unsigned arraySize = 1, const StringRef &varName = "")
241 {
242  // move builder to first position of entry BB
243  BasicBlock *entryBB = &llvm_getFunction(Builder)->getEntryBlock();
244  IRBuilder<>::InsertPoint oldIP = Builder.saveIP();
245  if (!entryBB->empty())
246  Builder.SetInsertPoint(&entryBB->front());
247  else
248  Builder.SetInsertPoint(entryBB);
249 
250  // allocate stack memory and store value to it.
251  LLVMContext &llvmContext = Builder.getContext();
252  LLVM_VALUE arraySizeVal = ConstantInt::get(Type::getInt32Ty(llvmContext), arraySize);
253  AllocaInst *varPtr = Builder.CreateAlloca(ty, arraySizeVal, static_cast<std::string>(varName));
254  // restore builder insertion position
255  Builder.restoreIP(oldIP);
256  return varPtr;
257 }
258 
259 AllocaInst *createArray(LLVM_BUILDER Builder, Type *ty, unsigned arraySize, const std::string &varName = "")
260 {
261  // move builder to first position of entry BB
262  BasicBlock *entryBB = &llvm_getFunction(Builder)->getEntryBlock();
263  IRBuilder<>::InsertPoint oldIP = Builder.saveIP();
264  if (!entryBB->empty())
265  Builder.SetInsertPoint(&entryBB->front());
266  else
267  Builder.SetInsertPoint(entryBB);
268 
269  // allocate stack memory and store value to it.
270  ArrayType *arrayTy = ArrayType::get(ty, arraySize);
271  AllocaInst *varPtr = Builder.CreateAlloca(arrayTy, nullptr, varName);
272  // restore builder insertion position
273  Builder.restoreIP(oldIP);
274  return varPtr;
275 }
276 
277 std::pair<LLVM_VALUE, LLVM_VALUE> promoteBinaryOperandsToAppropriateVector(LLVM_BUILDER Builder, LLVM_VALUE op1, LLVM_VALUE op2)
278 {
279  Type *op1Ty = op1->getType();
280  Type *op2Ty = op2->getType();
281  if (op1Ty == op2Ty)
282  return std::make_pair(op1, op2);
283 
284  LLVM_VALUE toPromote = op1;
285  LLVM_VALUE target = op2;
286  if (op1Ty->isVectorTy())
287  std::swap(toPromote, target);
288 
289  assert(target->getType()->isVectorTy());
290 
291  unsigned dim = getVectorNumElements(target->getType());
292  LLVM_VALUE vecVal = createVecVal(Builder, toPromote, dim);
293 
294  if (op1Ty->isVectorTy())
295  op2 = vecVal;
296  else
297  op1 = vecVal;
298 
299  return std::make_pair(op1, op2);
300 }
301 
302 LLVM_VALUE promoteOperand(LLVM_BUILDER Builder, const ExprType& refType, LLVM_VALUE val)
303 {
304  Type *valTy = val->getType();
305  if (refType.isFP() && refType.dim() > 1 && !valTy->isVectorTy()) {
306  return createVecVal(Builder, val, refType.dim());
307  } else {
308  return val;
309  }
310 }
311 
312 AllocaInst *storeVectorToDoublePtr(LLVM_BUILDER Builder, LLVM_VALUE vecVal)
313 {
314  LLVMContext &llvmContext = Builder.getContext();
315  AllocaInst *doublePtr = createAllocaInst(Builder, Type::getDoubleTy(llvmContext), getVectorNumElements(vecVal->getType()));
316  for (unsigned i = 0; i < 3; ++i) {
317  LLVM_VALUE idx = ConstantInt::get(Type::getInt32Ty(llvmContext), i);
318  LLVM_VALUE val = Builder.CreateExtractElement(vecVal, idx);
319  LLVM_VALUE ptr = CREATE_CONST_GEP1_32(Builder, doublePtr, i);
320  Builder.CreateStore(val, ptr);
321  }
322  return doublePtr;
323 }
324 
325 std::vector<LLVM_VALUE> codegenFuncCallArgs(LLVM_BUILDER Builder, const ExprFuncNode *funcNode)
326 {
327  std::vector<LLVM_VALUE> args;
328  args.reserve(funcNode->numChildren());
329  for (int i = 0; i < funcNode->numChildren(); ++i)
330  args.push_back(funcNode->child(i)->codegen(Builder));
331  return args;
332 }
333 
334 std::vector<LLVM_VALUE> promoteArgs(std::vector<LLVM_VALUE> args, LLVM_BUILDER Builder, FunctionType *llvmFuncType)
335 {
336  std::vector<LLVM_VALUE> ret;
337  for (unsigned i = 0; i < args.size(); ++i)
338  ret.push_back(promoteToTy(args[i], llvmFuncType->getParamType(i), Builder));
339  return ret;
340 }
341 
342 std::vector<LLVM_VALUE> promoteArgs(std::vector<LLVM_VALUE> args, LLVM_BUILDER Builder, ExprFuncStandard::FuncType seFuncType)
343 {
344  if (isTakeOnlyDoubleArg(seFuncType))
345  return args;
346 
347  LLVMContext &llvmContext = Builder.getContext();
348 #if LLVM_VERSION_MAJOR >= 10
349  VectorType *destTy = VectorType::get(Type::getDoubleTy(llvmContext), 3, false);
350 #else
351  VectorType *destTy = VectorType::get(Type::getDoubleTy(llvmContext), 3);
352 #endif
353  std::vector<LLVM_VALUE> ret;
354  ret.reserve(args.size());
355  for (auto & arg : args)
356  ret.push_back(promoteToTy(arg, destTy, Builder));
357  return ret;
358 }
359 
360 std::vector<LLVM_VALUE> replaceVecArgWithDoublePointer(LLVM_BUILDER Builder, std::vector<LLVM_VALUE> args)
361 {
362  for (auto & arg : args)
363  if (arg->getType()->isVectorTy())
364  arg = storeVectorToDoublePtr(Builder, arg);
365  return args;
366 }
367 
368 std::vector<LLVM_VALUE> convertArgsToPointerAndLength(LLVM_BUILDER Builder, std::vector<LLVM_VALUE> actualArgs, ExprFuncStandard::FuncType seFuncType)
369 {
370  assert(isVarArg(seFuncType));
371 
372  LLVMContext &llvmContext = Builder.getContext();
373  unsigned numArgs = actualArgs.size();
374 
375  // type of arg should be either double or double*(aka. vector).
376  for (unsigned i = 0; i < numArgs; ++i)
377  assert(actualArgs[i]->getType()->isDoubleTy() || actualArgs[i]->getType() == Type::getDoublePtrTy(llvmContext));
378 
379  std::vector<LLVM_VALUE> args;
380  // push "int n"
381  args.push_back(ConstantInt::get(Type::getInt32Ty(llvmContext), numArgs));
382 
383  if (seFuncType == ExprFuncStandard::FUNCN) {
384  AllocaInst *doublePtr = createAllocaInst(Builder, Type::getDoubleTy(llvmContext), numArgs);
385  for (unsigned i = 0; i < numArgs; ++i) {
386  LLVM_VALUE ptr = CREATE_CONST_GEP1_32(Builder, doublePtr, i);
387  Builder.CreateStore(actualArgs[i], ptr);
388  }
389  args.push_back(doublePtr);
390  return args;
391  }
392 
393  AllocaInst *arrayPtr = createArray(Builder, ArrayType::get(Type::getDoubleTy(llvmContext), 3), numArgs);
394  for (unsigned i = 0; i < numArgs; ++i) {
395  LLVM_VALUE toInsert = actualArgs[i];
396  LLVM_VALUE subArrayPtr = Builder.CreateConstGEP2_32(nullptr, arrayPtr, 0, i);
397  for (unsigned j = 0; j < 3; ++j) {
398  LLVM_VALUE destAddr = Builder.CreateConstGEP2_32(nullptr, subArrayPtr, 0, j);
399  LLVM_VALUE srcAddr = CREATE_CONST_GEP1_32(Builder, toInsert, j);
400  Builder.CreateStore(CREATE_LOAD(Builder, srcAddr), destAddr);
401  }
402  }
403  args.push_back(Builder.CreateBitCast(arrayPtr, Type::getDoublePtrTy(llvmContext)));
404  return args;
405 }
406 
407 LLVM_VALUE executeStandardFunction(LLVM_BUILDER Builder, ExprFuncStandard::FuncType seFuncType, std::vector<LLVM_VALUE> args, LLVM_VALUE addrVal)
408 {
409  LLVMContext &llvmContext = Builder.getContext();
410 
411  args = promoteArgs(args, Builder, seFuncType);
412  args = replaceVecArgWithDoublePointer(Builder, args);
413 
414  if (isVarArg(seFuncType))
415  args = convertArgsToPointerAndLength(Builder, args, seFuncType);
416 
417  if (isReturnVector(seFuncType) == false)
418  return CreateCall(Builder, addrVal, args);
419 
420  // TODO: assume standard function all use vector of length 3 as parameter
421  // or return type.
422  AllocaInst *retPtr = createAllocaInst(Builder, Type::getDoubleTy(llvmContext), 3);
423  args.insert(args.begin(), retPtr);
424  CreateCall(Builder, addrVal, replaceVecArgWithDoublePointer(Builder, args));
425  return createVecValFromAlloca(Builder, retPtr, 3);
426 }
427 
428 // TODO: Is this necessary? why not use printf custom function?
429 LLVM_VALUE callPrintf(const ExprFuncNode *seFunc, LLVM_BUILDER Builder, Function *callee)
430 {
431  LLVMContext &llvmContext = Builder.getContext();
432  std::vector<LLVM_VALUE> args;
433 
434  // TODO: promotion for printf?
435  { // preprocess format string.
436  const auto *formatStrNode = dynamic_cast<const ExprStrNode *>(seFunc->child(0));
437  assert(formatStrNode);
438  std::string formatStr(formatStrNode->str());
439  std::string::size_type pos = std::string::npos;
440  while ((pos = formatStr.find("%v")) != std::string::npos)
441  formatStr.replace(pos, 2, std::string("[%f,%f,%f]"));
442  formatStr.append("\n");
443  args.push_back(Builder.CreateGlobalStringPtr(formatStr));
444  }
445 
446  for (int i = 1; i < seFunc->numChildren(); ++i) {
447  LLVM_VALUE arg = seFunc->child(i)->codegen(Builder);
448  if (arg->getType()->isVectorTy()) {
449  AllocaInst *vecArray = storeVectorToDoublePtr(Builder, arg);
450  for (unsigned i = 0; i < getVectorNumElements(arg->getType()); ++i) {
451  LLVM_VALUE elemPtr = CREATE_CONST_GEP1_32(Builder, vecArray, i);
452  args.push_back(CREATE_LOAD(Builder, elemPtr));
453  }
454  } else
455  args.push_back(arg);
456  }
457 
458  CreateCall(Builder, callee, args);
459  return ConstantFP::get(Type::getDoubleTy(llvmContext), 0.0);
460 }
461 
462 // TODO: not good. need better implementation.
463 LLVM_VALUE callCustomFunction(const ExprFuncNode *funcNode, LLVM_BUILDER Builder)
464 {
465  LLVMContext &llvmContext = Builder.getContext();
466 
467  // get the function's arguments
468  std::vector<LLVM_VALUE> args = codegenFuncCallArgs(Builder, funcNode);
469  int nargs = funcNode->numChildren();
470  assert(nargs == (int)args.size());
471 
472  // get the number of items that the function returns
473  auto sizeOfRet = (unsigned)funcNode->type().dim();
474  assert(sizeOfRet == 1 || funcNode->type().isFP());
475 
476  // TODO: is this necessary ? Doesn't seem to be used :/
477  createAllocaInst(Builder, Type::getDoubleTy(llvmContext), sizeOfRet);
478 
479  // calculate how much space for opData, fpArg and strArg
480  unsigned sizeOfFpArgs = 1 + sizeOfRet;
481  unsigned sizeOfStrArgs = 2;
482  for (int i = 0; i < nargs; ++i) {
483  ExprType argType = funcNode->child(i)->type();
484  if (argType.isFP()) {
485  sizeOfFpArgs += std::max(funcNode->promote(i), argType.dim());
486  } else if (argType.isString()) {
487  sizeOfStrArgs += 1;
488  } else {
489  assert(false && "invalid type encountered");
490  }
491  }
492 
493  // a few types that are reused throughout this function
494  Type *int32Ty = Type::getInt32Ty(llvmContext); // int
495  Type *doubleTy = Type::getDoubleTy(llvmContext); // double
496  PointerType *int8PtrTy = Type::getInt8PtrTy(llvmContext); // char*
497  Type *int64Ty = Type::getInt64Ty(llvmContext); // int64_t
498 
499  // allocate data that we will feed to KSeExprLLVMEvalCustomFunction on the stack
500  AllocaInst *opDataArg = createAllocaInst(Builder, int32Ty, (unsigned)nargs + 4, "opDataArgPtr");
501  AllocaInst *fpArg = createAllocaInst(Builder, doubleTy, sizeOfFpArgs, "fpArgPtr");
502  AllocaInst *strArg = createAllocaInst(Builder, int8PtrTy, sizeOfStrArgs, "strArgPtr");
503 
504  // fill fpArgPtr's first value
505  Builder.CreateStore(ConstantFP::get(doubleTy, nargs), fpArg);
506 
507  // fill opDataArgPtr
508  Builder.CreateStore(ConstantInt::get(int32Ty, 0), CREATE_CONST_GEP1_32(Builder, opDataArg, 0));
509  Builder.CreateStore(ConstantInt::get(int32Ty, 1), CREATE_CONST_GEP1_32(Builder, opDataArg, 1));
510  Builder.CreateStore(ConstantInt::get(int32Ty, 1), CREATE_CONST_GEP1_32(Builder, opDataArg, 2));
511  Builder.CreateStore(ConstantInt::get(int32Ty, 0), CREATE_CONST_GEP1_32(Builder, opDataArg, 3));
512 
513  // Load arguments into the pseudo interpreter data structure
514  unsigned fpIdx = 1 + sizeOfRet;
515  unsigned strIdx = 2;
516  for (int argIndex = 0; argIndex < nargs; ++argIndex) {
517  int opIndex = argIndex + 4;
518  ExprType argType = funcNode->child(argIndex)->type();
519  if (argType.isFP()) {
520  // store the fpArgPtr indirection index
521  Builder.CreateStore(ConstantInt::get(int32Ty, fpIdx), CREATE_CONST_GEP1_32(Builder, opDataArg, opIndex));
522  if (argType.dim() > 1) {
523  for (int comp = 0; comp < argType.dim(); comp++) {
524  LLVM_VALUE compIndex = ConstantInt::get(int32Ty, comp);
525  LLVM_VALUE val = Builder.CreateExtractElement(args[argIndex], compIndex);
526  LLVM_VALUE fpArgPtr = CREATE_CONST_GEP1_32(Builder, fpArg, fpIdx + comp);
527  Builder.CreateStore(val, fpArgPtr);
528  }
529  fpIdx += argType.dim();
530  } else {
531  // TODO: this needs the promote!!!
532  int promote = funcNode->promote(argIndex);
533  if (promote) {
534  LLVM_VALUE val = args[argIndex];
535  for (int comp = 0; comp < promote; comp++) {
536  LLVM_VALUE fpArgPtr = CREATE_CONST_GEP1_32(Builder, fpArg, fpIdx + comp);
537  Builder.CreateStore(val, fpArgPtr);
538  }
539  fpIdx += promote;
540  } else {
541  Builder.CreateStore(args[argIndex], CREATE_CONST_GEP1_32(Builder, fpArg, fpIdx));
542  fpIdx++;
543  }
544  }
545  } else if (argType.isString()) {
546  // store the strArgPtr indirection index
547  Builder.CreateStore(ConstantInt::get(int32Ty, strIdx), CREATE_CONST_GEP1_32(Builder, opDataArg, opIndex));
548  Builder.CreateStore(args[argIndex], CREATE_CONST_GEP1_32(Builder, strArg, strIdx));
549  strIdx++;
550  }
551  }
552 
553  // get the module from the builder
554  Module *module = llvm_getModule(Builder);
555 
556  // TODO: thread safety?
557  // TODO: This leaks!
558  auto *dataGV = new GlobalVariable(*module, int8PtrTy, false, GlobalValue::InternalLinkage, ConstantPointerNull::get(int8PtrTy));
559 
560  // call the function
561  Builder.CreateCall(module->getFunction("KSeExprLLVMEvalCustomFunction"), {opDataArg, fpArg, strArg, dataGV, ConstantInt::get(int64Ty, reinterpret_cast<uint64_t>(funcNode))});
562 
563  // read the result from memory
564  int resultOffset = 1;
565  if (funcNode->type().isFP()) {
566  if (sizeOfRet == 1) {
567  return CREATE_LOAD(Builder, CREATE_CONST_GEP1_32(Builder, fpArg, resultOffset));
568  } else if (sizeOfRet > 1) {
569  std::vector<LLVM_VALUE> resultArray;
570  for (unsigned int comp = 0; comp < sizeOfRet; comp++) {
571  LLVM_VALUE ptr = CREATE_CONST_GEP1_32(Builder, fpArg, resultOffset + comp); // skip nargs
572  resultArray.push_back(CREATE_LOAD(Builder, ptr));
573  }
574  return createVecVal(Builder, resultArray);
575  }
576  } else {
577  return CREATE_LOAD(Builder, CREATE_CONST_GEP1_32(Builder, strArg, 1));
578  }
579 
580  assert(false);
581  return nullptr;
582 }
583 } // namespace
584 
585 extern "C" void KSeExprLLVMEvalFPVarRef(ExprVarRef *seVR, double *result)
586 {
587  seVR->eval(result);
588 }
589 extern "C" void KSeExprLLVMEvalStrVarRef(ExprVarRef *seVR, char **result)
590 {
591  seVR->eval((const char **)result);
592 }
593 
594 namespace KSeExpr
595 {
596 LLVM_VALUE promoteToDim(LLVM_VALUE val, unsigned dim, LLVM_BUILDER Builder)
597 {
598  Type *srcTy = val->getType();
599  if (srcTy->isVectorTy() || dim <= 1)
600  return val;
601 
602  assert(srcTy->isDoubleTy());
603  return createVecVal(Builder, val, dim);
604 }
605 
606 LLVM_VALUE ExprNode::codegen(LLVM_BUILDER Builder) const
607 {
608  for (int i = 0; i < numChildren(); i++)
609  child(i)->codegen(Builder);
610  return nullptr;
611 }
612 
613 LLVM_VALUE ExprModuleNode::codegen(LLVM_BUILDER Builder) const
614 {
615  LLVM_VALUE lastVal = nullptr;
616  for (int i = 0; i < numChildren(); i++)
617  lastVal = child(i)->codegen(Builder);
618  assert(lastVal);
619  return lastVal;
620 }
621 
622 LLVM_VALUE ExprBlockNode::codegen(LLVM_BUILDER Builder) const
623 {
624  LLVM_VALUE lastVal = nullptr;
625  for (int i = 0; i < numChildren(); i++)
626  lastVal = child(i)->codegen(Builder);
627  assert(lastVal);
628  return lastVal;
629 }
630 
631 LLVM_VALUE ExprNumNode::codegen(LLVM_BUILDER Builder) const
632 {
633  return ConstantFP::get(Builder.getContext(), APFloat(_val));
634 }
635 
636 LLVM_VALUE ExprBinaryOpNode::codegen(LLVM_BUILDER Builder) const
637 {
638  LLVM_VALUE c1 = child(0)->codegen(Builder);
639  LLVM_VALUE c2 = child(1)->codegen(Builder);
640  std::pair<LLVM_VALUE, LLVM_VALUE> pv = promoteBinaryOperandsToAppropriateVector(Builder, c1, c2);
641  LLVM_VALUE op1 = pv.first;
642  LLVM_VALUE op2 = pv.second;
643 
644  const bool isString = child(0)->type().isString();
645 
646  if (isString == false) {
647  switch (_op) {
648  case '+':
649  return Builder.CreateFAdd(op1, op2);
650  case '-':
651  return Builder.CreateFSub(op1, op2);
652  case '*':
653  return Builder.CreateFMul(op1, op2);
654  case '/':
655  return Builder.CreateFDiv(op1, op2);
656  case '%': {
657  // niceMod() from v1: b==0 ? 0 : a-floor(a/b)*b
658  LLVM_VALUE a = op1;
659  LLVM_VALUE b = op2;
660  LLVM_VALUE aOverB = Builder.CreateFDiv(a, b);
661  Function *floorFun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::floor, op1->getType());
662  LLVM_VALUE normal = Builder.CreateFSub(a, Builder.CreateFMul(Builder.CreateCall(floorFun, {aOverB}), b));
663  Constant *zero = ConstantFP::get(op1->getType(), 0.0);
664  return Builder.CreateSelect(Builder.CreateFCmpOEQ(zero, op1), zero, normal);
665  }
666  case '^': {
667  // TODO: make external function reference work with interpreter, libffi
668  // TODO: needed for MCJIT??
669  // TODO: is the above not already done?!
670  std::vector<Type *> arg_type;
671  arg_type.push_back(op1->getType());
672  Function *fun = Intrinsic::getDeclaration(llvm_getModule(Builder), Intrinsic::pow, arg_type);
673  std::vector<LLVM_VALUE> ops = {op1, op2};
674  return Builder.CreateCall(fun, ops);
675  }
676  }
677  } else {
678  // precompute a few things
679  LLVMContext &context = Builder.getContext();
680  Module *module = llvm_getModule(Builder);
681  PointerType *i8PtrPtrTy = PointerType::getUnqual(Type::getInt8PtrTy(context));
682  Type *i32Ty = Type::getInt32Ty(context);
683  Function *strlen = module->getFunction("strlen");
684  Function *malloc = module->getFunction("malloc");
685  Function *free = module->getFunction("free");
686  Function *memset = module->getFunction("memset");
687  Function *strcat = module->getFunction("strcat");
688 
689  // do magic (see the pseudo C code on the comments at the end
690  // of each LLVM instruction)
691 
692  // compute the length of the operand strings
693  LLVM_VALUE len1 = Builder.CreateCall(strlen, {op1}); // len1 = strlen(op1);
694  LLVM_VALUE len2 = Builder.CreateCall(strlen, {op2}); // len2 = strlen(op2);
695  LLVM_VALUE len = Builder.CreateAdd(len1, len2); // len = len1 + len2;
696 
697  // allocate and clear memory
698  LLVM_VALUE alloc = Builder.CreateCall(malloc, {len}); // alloc = malloc(len1 + len2);
699  LLVM_VALUE zero = ConstantInt::get(i32Ty, 0); // zero = 0;
700  Builder.CreateCall(memset, {alloc, zero, len}); // memset(alloc, zero, len);
701 
702  // concatenate operand strings into output string
703  Builder.CreateCall(strcat, {alloc, op1}); // strcat(alloc, op1);
704  LLVM_VALUE newAlloc = Builder.CreateGEP(nullptr, alloc, len1); // newAlloc = alloc + len1
705  Builder.CreateCall(strcat, {newAlloc, op2}); // strcat(alloc, op2);
706 
707  // store the address in the node's _out member so that it will be
708  // cleaned up when the expression is destroyed.
709  APInt outAddr = APInt(64, reinterpret_cast<uint64_t>(&_out));
710  LLVM_VALUE out = Constant::getIntegerValue(i8PtrPtrTy, outAddr); // out = &_out;
711  Builder.CreateCall(free, {CREATE_LOAD(Builder, out)}); // free(*out);
712  Builder.CreateStore(alloc, out); // *out = alloc
713  return alloc;
714  }
715 
716  assert(false && "unexpected op");
717  return nullptr;
718 }
719 
720 // This is the def of def-use chain
721 // We don't go to VarNode::codegen. It is codegen'd here.
722 LLVM_VALUE ExprAssignNode::codegen(LLVM_BUILDER Builder) const
723 {
724  // codegen value to store
725  LLVM_VALUE val = child(0)->codegen(Builder);
726  // code gen pointer to store into
727  const std::string &varName = name();
728  LLVM_VALUE varPtr = _localVar->codegen(Builder, varName, val);
729  // do actual store
730  Builder.CreateStore(val, varPtr);
731  return nullptr;
732 }
733 
735 LLVM_VALUE ExprLocalVar::codegen(LLVM_BUILDER Builder, const std::string &varName, LLVM_VALUE refValue) const
736 {
737  _varPtr = createAllocaInst(Builder, refValue->getType(), 1, varName);
738  return _varPtr;
739 }
740 
741 LLVM_VALUE ExprCompareEqNode::codegen(LLVM_BUILDER Builder) const
742 {
743  LLVM_VALUE op1 = getFirstElement(child(0)->codegen(Builder), Builder);
744  LLVM_VALUE op2 = getFirstElement(child(1)->codegen(Builder), Builder);
745 
746  LLVM_VALUE boolVal = nullptr;
747 
748  const bool isString = child(0)->type().isString();
749 
750  if (isString == false) {
751  switch (_op) {
752  case '!':
753  boolVal = Builder.CreateFCmpONE(op1, op2);
754  break;
755  case '=':
756  boolVal = Builder.CreateFCmpOEQ(op1, op2);
757  break;
758  default:
759  assert(false && "Unkown CompareEq op.");
760  }
761  return Builder.CreateUIToFP(boolVal, op1->getType());
762  } else {
763  // precompute a few things
764  LLVMContext &llvmContext = Builder.getContext();
765  Module *module = llvm_getModule(Builder);
766  Type *doubleTy = Type::getDoubleTy(llvmContext);
767  Function *strcmp = module->getFunction("strcmp");
768 
769  LLVM_VALUE val = Builder.CreateCall(strcmp, {op1, op2}); // val = strcmp(op1, op2);
770  Constant *zero = ConstantInt::get(strcmp->getReturnType(), 0);
771 
772  switch (_op) {
773  case '!':
774  boolVal = Builder.CreateICmpNE(val, zero); // boolVal = val == 0;
775  break;
776  case '=':
777  boolVal = Builder.CreateICmpEQ(val, zero); // boolVal = val != 0;
778  break;
779  default:
780  assert(false && "Unkown CompareEq op.");
781  }
782  return Builder.CreateUIToFP(boolVal, doubleTy);
783  }
784 }
785 
786 LLVM_VALUE ExprCompareNode::codegen(LLVM_BUILDER Builder) const
787 {
788  if (_op == '&' || _op == '|') {
789  // Handle & and | specially as conditionals to handle short circuiting!
790  LLVMContext &llvmContext = Builder.getContext();
791 
792  LLVM_VALUE op1 = getFirstElement(child(0)->codegen(Builder), Builder);
793  Type *opTy = op1->getType();
794  Constant *zero = ConstantFP::get(opTy, 0.0);
795 
796  LLVM_VALUE op1IsOne = Builder.CreateFCmpUNE(op1, zero);
797 
798  Function *F = llvm_getFunction(Builder);
799  BasicBlock *thenBlock = BasicBlock::Create(llvmContext, "then", F);
800  BasicBlock *elseBlock = BasicBlock::Create(llvmContext, "else", F);
801  BasicBlock *phiBlock = BasicBlock::Create(llvmContext, "phi", F);
802  Builder.CreateCondBr(op1IsOne, thenBlock, elseBlock);
803 
804  LLVM_VALUE op2IsOne = nullptr;
805  Type *intTy = Type::getInt1Ty(llvmContext);
806  Type *doubleTy = Type::getDoubleTy(llvmContext);
807  llvm::PHINode *phiNode = nullptr;
808  if (_op == '&') {
809  // TODO: full IfThenElsenot needed
810  Builder.SetInsertPoint(thenBlock);
811  LLVM_VALUE op2 = child(1)->codegen(Builder);
812  op2IsOne = Builder.CreateFCmpUNE(op2, zero);
813  Builder.CreateBr(phiBlock);
814  thenBlock = Builder.GetInsertBlock();
815 
816  Builder.SetInsertPoint(elseBlock);
817  Builder.CreateBr(phiBlock);
818  Builder.SetInsertPoint(phiBlock);
819 
820  phiNode = Builder.CreatePHI(intTy, 2, "iftmp");
821  phiNode->addIncoming(op2IsOne, thenBlock);
822  phiNode->addIncoming(op1IsOne, elseBlock);
823  } else if (_op == '|') {
824  // TODO: full IfThenElsenot needed
825  Builder.SetInsertPoint(thenBlock);
826  Builder.CreateBr(phiBlock);
827 
828  Builder.SetInsertPoint(elseBlock);
829  LLVM_VALUE op2 = child(1)->codegen(Builder);
830  op2IsOne = Builder.CreateFCmpUNE(op2, zero);
831  Builder.CreateBr(phiBlock);
832  elseBlock = Builder.GetInsertBlock();
833 
834  Builder.SetInsertPoint(phiBlock);
835  phiNode = Builder.CreatePHI(intTy, 2, "iftmp");
836  phiNode->addIncoming(op1IsOne, thenBlock);
837  phiNode->addIncoming(op2IsOne, elseBlock);
838  } else {
839  throw std::runtime_error("Logical inconsistency.");
840  }
841  LLVM_VALUE out = Builder.CreateUIToFP(phiNode, doubleTy);
842  return out;
843  } else {
844  LLVM_VALUE op1 = getFirstElement(child(0)->codegen(Builder), Builder);
845  LLVM_VALUE op2 = getFirstElement(child(1)->codegen(Builder), Builder);
846 
847  Type *opTy = op1->getType();
848  Constant *zero = ConstantFP::get(opTy, 0.0);
849  LLVM_VALUE boolVal = nullptr;
850 
851  switch (_op) {
852  case '|': {
853  LLVM_VALUE op1IsOne = Builder.CreateFCmpUNE(op1, zero);
854  LLVM_VALUE op2IsOne = Builder.CreateFCmpUNE(op2, zero);
855  boolVal = Builder.CreateOr(op1IsOne, op2IsOne);
856  break;
857  }
858  case '&': {
859  assert(false); // handled above
860  break;
861  }
862  case 'g':
863  boolVal = Builder.CreateFCmpOGE(op1, op2);
864  break;
865  case 'l':
866  boolVal = Builder.CreateFCmpOLE(op1, op2);
867  break;
868  case '>':
869  boolVal = Builder.CreateFCmpOGT(op1, op2);
870  break;
871  case '<':
872  boolVal = Builder.CreateFCmpOLT(op1, op2);
873  break;
874  default:
875  assert(false && "Unkown Compare op.");
876  }
877 
878  return Builder.CreateUIToFP(boolVal, opTy);
879  }
880 }
881 
882 LLVM_VALUE ExprCondNode::codegen(LLVM_BUILDER Builder) const
883 {
884 #if 0 // old non-short circuit
885  LLVM_VALUE condVal = getFirstElement(child(0)->codegen(Builder), Builder);
886  LLVM_VALUE cond = Builder.CreateFCmpUNE(condVal,
887  ConstantFP::get(condVal->getType(), 0.0));
888  LLVM_VALUE trueVal = child(1)->codegen(Builder);
889  LLVM_VALUE falseVal = child(2)->codegen(Builder);
890  std::pair<LLVM_VALUE, LLVM_VALUE> pv = promoteBinaryOperandsToAppropriateVector(Builder, trueVal, falseVal);
891  return Builder.CreateSelect(cond, pv.first, pv.second);
892 #else // new short circuit version
893  LLVM_VALUE condVal = getFirstElement(child(0)->codegen(Builder), Builder);
894  LLVM_VALUE condAsBool = Builder.CreateFCmpUNE(condVal, ConstantFP::get(condVal->getType(), 0.0));
895  LLVMContext &llvmContext = Builder.getContext();
896  Function *F = llvm_getFunction(Builder);
897  BasicBlock *thenBlock = BasicBlock::Create(llvmContext, "then", F);
898  BasicBlock *elseBlock = BasicBlock::Create(llvmContext, "else", F);
899  BasicBlock *phiBlock = BasicBlock::Create(llvmContext, "phi", F);
900  Builder.CreateCondBr(condAsBool, thenBlock, elseBlock);
901 
902  Builder.SetInsertPoint(thenBlock);
903  LLVM_VALUE trueVal = promoteOperand(Builder, _type, child(1)->codegen(Builder));
904  Builder.CreateBr(phiBlock);
905  thenBlock = Builder.GetInsertBlock();
906 
907  Builder.SetInsertPoint(elseBlock);
908  LLVM_VALUE falseVal = promoteOperand(Builder, _type, child(2)->codegen(Builder));
909  Builder.CreateBr(phiBlock);
910  elseBlock = Builder.GetInsertBlock();
911 
912  Builder.SetInsertPoint(phiBlock);
913  llvm::PHINode *phiNode = Builder.CreatePHI(trueVal->getType(), 2, "iftmp");
914  phiNode->addIncoming(trueVal, thenBlock);
915  phiNode->addIncoming(falseVal, elseBlock);
916  return phiNode;
917 
918 #endif
919 }
920 
921 LLVM_VALUE ExprFuncNode::codegen(LLVM_BUILDER Builder) const
922 {
923  LLVMContext &llvmContext = Builder.getContext();
924  Module *M = llvm_getModule(Builder);
925  std::string calleeName(name());
926 
927  /************* call local function or printf *************/
928  Function *callee = M->getFunction(calleeName);
929  if (calleeName == "printf") {
930  if (!callee) {
931  FunctionType *FT = FunctionType::get(Type::getVoidTy(llvmContext), Type::getInt8PtrTy(llvmContext), true);
932  callee = Function::Create(FT, GlobalValue::ExternalLinkage, "printf", llvm_getModule(Builder));
933  }
934  return callPrintf(this, Builder, callee);
935  } else if (callee) {
936  std::vector<LLVM_VALUE> args = promoteArgs(codegenFuncCallArgs(Builder, this), Builder, callee->getFunctionType());
937  return Builder.CreateCall(callee, args);
938  }
939 
940  /************* call standard function or custom function *************/
941  // call custom function
942  const auto *standfunc = dynamic_cast<const ExprFuncStandard *>(_func->funcx());
943  if (!standfunc)
944  return callCustomFunction(this, Builder);
945 
946  // call standard function
947  // get function pointer
948  ExprFuncStandard::FuncType seFuncType = standfunc->getFuncType();
949  FunctionType *llvmFuncType = getSeExprFuncStandardLLVMType(seFuncType, llvmContext);
950  void *fp = standfunc->getFuncPointer();
951  ConstantInt *funcAddr = ConstantInt::get(Type::getInt64Ty(llvmContext), reinterpret_cast<uint64_t>(fp));
952  LLVM_VALUE addrVal = Builder.CreateIntToPtr(funcAddr, PointerType::getUnqual(llvmFuncType));
953 
954  // Collect distribution positions
955  std::vector<LLVM_VALUE> args = codegenFuncCallArgs(Builder, this);
956  std::vector<int> argumentIsVectorAndNeedsDistribution(args.size(), 0);
957  Type *maxVectorArgType = nullptr;
958  if (seFuncType == ExprFuncStandard::FUNCN) {
959  for (unsigned i = 0; i < args.size(); ++i) {
960  if (args[i]->getType()->isVectorTy()) {
961  maxVectorArgType = args[i]->getType();
962  argumentIsVectorAndNeedsDistribution[i] = 1;
963  }
964  }
965  } else if (seFuncType == ExprFuncStandard::FUNCNV || seFuncType == ExprFuncStandard::FUNCNVV) {
966  } else {
967  unsigned shift = isReturnVector(seFuncType) ? 1 : 0;
968  for (unsigned i = 0; i < args.size(); ++i) {
969  Type *paramType = llvmFuncType->getParamType(i + shift);
970  Type *argType = args[i]->getType();
971  if (argType->isVectorTy() && paramType->isDoubleTy()) {
972  maxVectorArgType = args[i]->getType();
973  argumentIsVectorAndNeedsDistribution[i] = 1;
974  }
975  }
976  }
977 
978  if (!maxVectorArgType) // nothing needs distribution so just execute normally
979  return executeStandardFunction(Builder, seFuncType, args, addrVal);
980 
981  assert(maxVectorArgType->isVectorTy());
982 
983  std::vector<LLVM_VALUE> ret;
984  for (unsigned vecComponent = 0; vecComponent < getVectorNumElements(maxVectorArgType); ++vecComponent) {
985  LLVM_VALUE idx = ConstantInt::get(Type::getInt32Ty(llvmContext), vecComponent);
986  std::vector<LLVM_VALUE> realArgs;
987  // Break the function into multiple calls per component of the output
988  // i.e. sin([1,2,3]) should be [sin(1),sin(2),sin(3)]
989  for (unsigned argIndex = 0; argIndex < args.size(); ++argIndex) {
990  LLVM_VALUE realArg = args[argIndex];
991  if (argumentIsVectorAndNeedsDistribution[argIndex]) {
992  if (args[argIndex]->getType()->isPointerTy())
993  realArg = CREATE_LOAD(Builder, Builder.CreateConstGEP2_32(nullptr, args[argIndex], 0, vecComponent));
994  else
995  realArg = Builder.CreateExtractElement(args[argIndex], idx);
996  }
997  realArgs.push_back(realArg);
998  }
999  ret.push_back(executeStandardFunction(Builder, seFuncType, realArgs, addrVal));
1000  }
1001  return createVecVal(Builder, ret);
1002 }
1003 
/// Generates code for an `if (cond) { ... } else { ... }` statement.
///
/// The condition's first element is compared against 0.0 (unordered, so NaN
/// counts as true). The two arms are emitted into "then" and "else" blocks that
/// both branch to a "phi" block. For every valid variable in the environment
/// merge at `_varEnvMergeIndex`, the value each arm left in its own copy of the
/// variable is loaded and promoted to the merged type. The two values are joined
/// with a PHI node and stored into the merged variable's storage. The statement
/// produces no value, so this always returns nullptr, and code generation
/// continues at the end of the "phi" block.
LLVM_VALUE ExprIfThenElseNode::codegen(LLVM_BUILDER Builder) const
{
    LLVM_VALUE condVal = getFirstElement(child(0)->codegen(Builder), Builder);
    Type *condTy = condVal->getType();

    LLVMContext &llvmContext = Builder.getContext();

    Constant *zero = ConstantFP::get(condTy, 0.0);
    LLVM_VALUE intCond = Builder.CreateFCmpUNE(condVal, zero);

    Function *F = llvm_getFunction(Builder);
    BasicBlock *thenBlock = BasicBlock::Create(llvmContext, "then", F);
    BasicBlock *elseBlock = BasicBlock::Create(llvmContext, "else", F);
    BasicBlock *phiBlock = BasicBlock::Create(llvmContext, "phi", F);
    Builder.CreateCondBr(intCond, thenBlock, elseBlock);

    // Emit each arm. An arm may create blocks of its own, so remember the block
    // it ended in: that block is the real predecessor of phiBlock. The branches
    // into phiBlock are deliberately NOT added yet (see the end of the function).
    Builder.SetInsertPoint(thenBlock);
    child(1)->codegen(Builder);
    thenBlock = Builder.GetInsertBlock();

    Builder.SetInsertPoint(elseBlock);
    child(2)->codegen(Builder);
    elseBlock = Builder.GetInsertBlock();

    // Build one PHI per merged variable. The loads that feed each PHI are emitted
    // at the end of the then/else blocks rather than in phiBlock, because LLVM
    // requires PHI nodes to be grouped at the very start of their block.
    Builder.SetInsertPoint(phiBlock);
    const auto &merges = _varEnv->merge(_varEnvMergeIndex);
    std::vector<LLVM_VALUE> phis;
    phis.reserve(merges.size());
    for (const auto &it : merges) {
        ExprLocalVarPhi *finalVar = it.second;
        if (finalVar->valid()) {
            ExprType refType = finalVar->type();
            Builder.SetInsertPoint(thenBlock);
            LLVM_VALUE thenValue = promoteOperand(Builder, refType, CREATE_LOAD(Builder, finalVar->_thenVar->varPtr()));
            Builder.SetInsertPoint(elseBlock);
            LLVM_VALUE elseValue = promoteOperand(Builder, refType, CREATE_LOAD(Builder, finalVar->_elseVar->varPtr()));

            Type *finalType = thenValue->getType();
            Builder.SetInsertPoint(phiBlock);
            PHINode *phi = Builder.CreatePHI(finalType, 2, it.first);
            phi->addIncoming(thenValue, thenBlock);
            phi->addIncoming(elseValue, elseBlock);
            phis.push_back(phi);
        }
    }
    // Now that all PHIs exist, store them into the merged variables. `phis` has one
    // entry per *valid* merge, in iteration order, so `idx` advances only for valid
    // entries. NOTE(review): finalVar->codegen presumably creates/returns the merged
    // variable's storage named "<name>-merge" — confirm in ExprLocalVarPhi.
    int idx = 0;
    for (auto &it : _varEnv->merge(_varEnvMergeIndex)) {
        const std::string &name = it.first;
        ExprLocalVarPhi *finalVar = it.second;
        if (finalVar->valid()) {
            LLVM_VALUE _finalVarPtr = finalVar->codegen(Builder, name + "-merge", phis[idx]);
            Builder.CreateStore(phis[idx++], _finalVarPtr);
        }
    }
    // Insert the ending jumps out of the then/else blocks last, so the loads
    // emitted above come before each block's terminator.
    Builder.SetInsertPoint(thenBlock);
    Builder.CreateBr(phiBlock);
    Builder.SetInsertPoint(elseBlock);
    Builder.CreateBr(phiBlock);
    // Continue emitting subsequent code at the end of the merge block.
    Builder.SetInsertPoint(phiBlock);

    return nullptr;
}
1071 
1072 LLVM_VALUE ExprLocalFunctionNode::codegen(LLVM_BUILDER Builder) const
1073 {
1074  IRBuilder<>::InsertPoint oldIP = Builder.saveIP();
1075  LLVMContext &llvmContext = Builder.getContext();
1076 
1077  // codegen prototype
1078  auto *F = cast<Function>(child(0)->codegen(Builder));
1079 
1080  // create alloca for args
1081  BasicBlock *BB = BasicBlock::Create(llvmContext, "entry", F);
1082  Builder.SetInsertPoint(BB);
1083  for (auto & AI : F->args()) {
1084  AllocaInst *Alloca = createAllocaInst(Builder, AI.getType(), 1, AI.getName());
1085  Alloca->takeName(&AI);
1086  Builder.CreateStore(&AI, Alloca);
1087  }
1088 
1089  LLVM_VALUE result = nullptr;
1090  for (int i = 1; i < numChildren(); i++)
1091  result = child(i)->codegen(Builder);
1092 
1093  Builder.CreateRet(result);
1094  Builder.restoreIP(oldIP);
1095  return nullptr;
1096 }
1097 
1098 LLVM_VALUE ExprPrototypeNode::codegen(LLVM_BUILDER Builder) const
1099 {
1100  LLVMContext &llvmContext = Builder.getContext();
1101 
1102  // get arg type
1103  std::vector<Type *> ParamTys;
1104  ParamTys.reserve(numChildren());
1105  for (int i = 0; i < numChildren(); ++i)
1106  ParamTys.push_back(createLLVMTyForSeExprType(llvmContext, argType(i)));
1107  // get ret type
1108  Type *retTy = createLLVMTyForSeExprType(llvmContext, returnType());
1109 
1110  FunctionType *FT = FunctionType::get(retTy, ParamTys, false);
1111  Function *F = Function::Create(FT, GlobalValue::InternalLinkage, name(), llvm_getModule(Builder));
1112 
1113  // Set names for all arguments.
1114  auto *AI = F->arg_begin();
1115  for (int i = 0, e = numChildren(); i != e; ++i, ++AI) {
1116  const auto *childNode = dynamic_cast<const ExprVarNode *>(child(i));
1117  assert(childNode);
1118  AI->setName(childNode->name());
1119  }
1120 
1121  return F;
1122 }
1123 
1124 LLVM_VALUE ExprStrNode::codegen(LLVM_BUILDER Builder) const
1125 {
1126  return Builder.CreateGlobalStringPtr(unescapeString(_str));
1127 }
1128 
1129 LLVM_VALUE ExprSubscriptNode::codegen(LLVM_BUILDER Builder) const
1130 {
1131  LLVM_VALUE op1 = child(0)->codegen(Builder);
1132  LLVM_VALUE op2 = child(1)->codegen(Builder);
1133 
1134  if (op1->getType()->isDoubleTy())
1135  return op1;
1136 
1137  LLVMContext &llvmContext = Builder.getContext();
1138  LLVM_VALUE idx = Builder.CreateFPToUI(op2, Type::getInt32Ty(llvmContext));
1139  return Builder.CreateExtractElement(op1, idx);
1140 }
1141 
1142 LLVM_VALUE ExprUnaryOpNode::codegen(LLVM_BUILDER Builder) const
1143 {
1144  LLVM_VALUE op1 = child(0)->codegen(Builder);
1145  Type *op1Ty = op1->getType();
1146  Constant *negateZero = ConstantFP::getZeroValueForNegation(op1Ty);
1147  Constant *zero = ConstantFP::get(op1Ty, 0.0);
1148  Constant *one = ConstantFP::get(op1Ty, 1.0);
1149 
1150  switch (_op) {
1151  case '-':
1152  return Builder.CreateFSub(negateZero, op1);
1153  case '~': {
1154  LLVM_VALUE neg = Builder.CreateFSub(negateZero, op1);
1155  return Builder.CreateFAdd(neg, one);
1156  }
1157  case '!': {
1158  LLVM_VALUE eqZero = Builder.CreateFCmpOEQ(zero, op1);
1159  return Builder.CreateSelect(eqZero, one, zero);
1160  }
1161  }
1162 
1163  assert(false && "not implemented.");
1164  return nullptr;
1165 }
1166 
/// Code generation for reads of external (non-local) variables.
struct VarCodeGeneration {
    /// Reads a generic ExprVarRef by calling back into the host at run time.
    ///
    /// Allocates `dim` stack slots (doubles, or char pointers for non-FP
    /// variables) and calls KSeExprLLVMEvalFPVarRef / KSeExprLLVMEvalStrVarRef.
    /// The ExprVarRef address is baked in as a constant and the slots are passed
    /// as the output. The result is then loaded as a scalar (dim == 1) or packed
    /// into a vector.
    /// NOTE(review): the varRef address is embedded in the generated code, so the
    /// ExprVarRef must outlive the compiled function. This also assumes the two
    /// eval helpers were already declared in the module, since getFunction returns
    /// nullptr otherwise.
    static LLVM_VALUE codegen(ExprVarRef *varRef, const std::string &varName, LLVM_BUILDER Builder)
    {
        LLVMContext &llvmContext = Builder.getContext();

        // a few types
        Type *int64Ty = Type::getInt64Ty(llvmContext); // int64_t
        Type *doubleTy = Type::getDoubleTy(llvmContext); // double
        PointerType *int8PtrTy = Type::getInt8PtrTy(llvmContext); // char *

        // get var informations
        bool isDouble = varRef->type().isFP();
        int dim = varRef->type().dim();

        // create the return value on the stack
        AllocaInst *returnValue = createAllocaInst(Builder, isDouble ? doubleTy : int8PtrTy, dim);

        // get our eval var function, and call it with a pointer to our var ref and a ref to the return value
        Function *evalVarFunc = llvm_getModule(Builder)->getFunction(isDouble == true ? "KSeExprLLVMEvalFPVarRef" : "KSeExprLLVMEvalStrVarRef");
        Builder.CreateCall(evalVarFunc, {Builder.CreateIntToPtr(ConstantInt::get(int64Ty, reinterpret_cast<uint64_t>(varRef)), int8PtrTy), returnValue});

        // load our return value
        LLVM_VALUE ret = 0;
        if (dim == 1) {
            ret = CREATE_LOAD(Builder, returnValue);
        } else {
            ret = createVecValFromAlloca(Builder, returnValue, dim);
        }

        // NOTE(review): this named stack slot is written but never read here and not
        // returned. It presumably only serves debugging/naming — confirm before removing.
        AllocaInst *thisvar = createAllocaInst(Builder, ret->getType(), 1, varName);
        Builder.CreateStore(ret, thisvar);
        return ret;
    }

    /// Reads a variable stored in the variable block passed to the compiled function.
    ///
    /// Uses the compiled function's 2nd argument (the variable block) and 3rd
    /// argument (the indirect index). The entry at offset() in the variable block
    /// is loaded and used as the variable's base pointer (presumably one base
    /// pointer per registered variable — verify against VarBlock). Uniform variables
    /// are read from the base pointer directly. Varying vector variables are read at
    /// base[indirectIndex * stride() + component].
    static LLVM_VALUE codegen(VarBlockCreator::Ref *varRef, const std::string &varName, LLVM_BUILDER Builder)
    {
        LLVMContext &llvmContext = Builder.getContext();

        int variableOffset = varRef->offset();
        int variableStride = varRef->stride();
        Function *function = llvm_getFunction(Builder);
        auto *argIterator = function->arg_begin();
        argIterator++; // skip first arg
        llvm::Argument *variableBlock = &*(argIterator++);
        llvm::Argument *indirectIndex = &*(argIterator++);

        int dim = varRef->type().dim();

        // NOTE(review): casting the argument to its own type is a no-op.
        Type *ptrToPtrTy = variableBlock->getType();
        Value *variableBlockAsPtrPtr = Builder.CreatePointerCast(variableBlock, ptrToPtrTy);
        Value *variableOffsetIndex = ConstantInt::get(Type::getInt32Ty(llvmContext), variableOffset);
        Value *variableBlockIndirectPtrPtr = IN_BOUNDS_GEP(Builder, variableBlockAsPtrPtr, variableOffsetIndex);
        Value *baseMemory = CREATE_LOAD(Builder, variableBlockIndirectPtrPtr);
        Value *variableStrideValue = ConstantInt::get(Type::getInt32Ty(llvmContext), variableStride);
        if (dim == 1) {
            // Scalar: uniform -> *baseMemory, varying -> baseMemory[indirectIndex].
            // NOTE(review): the stride is not applied here, which assumes stride() == 1 for scalars.
            Value *variablePointer = varRef->type().isLifetimeUniform() ? baseMemory : IN_BOUNDS_GEP(Builder, baseMemory, indirectIndex);
            return CREATE_LOAD(Builder, variablePointer);
        } else {
            std::vector<Value *> loadedValues(dim);
            for (int component = 0; component < dim; component++) {
                Value *componentIndex = ConstantInt::get(Type::getInt32Ty(llvmContext), component);
                // Vector: uniform -> baseMemory[component],
                //         varying -> baseMemory[indirectIndex * stride + component].
                Value *variablePointer = varRef->type().isLifetimeUniform() ? Builder.CreateInBoundsGEP(Type::getDoubleTy(llvmContext), baseMemory, componentIndex)
                                                                            : Builder.CreateInBoundsGEP(Type::getDoubleTy(llvmContext), baseMemory, Builder.CreateAdd(Builder.CreateMul(indirectIndex, variableStrideValue), componentIndex));
                loadedValues[component] = CREATE_LOAD_WITH_ID(Builder, variablePointer, varName);
            }
            return createVecVal(Builder, loadedValues, varName);
        }
    }
};
1239 
1240 // This is the use of def-use chain
1241 LLVM_VALUE ExprVarNode::codegen(LLVM_BUILDER Builder) const
1242 {
1243  if (_var) {
1244  // All external var has the prefix "external_" in current function to avoid
1245  // potential name conflict with local variable
1246  std::string varName("external_");
1247  varName.append(name());
1248  // if (LLVM_VALUE valPtr = resolveLocalVar(varName.c_str(), Builder))
1249  // return CREATE_LOAD(Builder, valPtr);
1250  if (auto *varBlockRef = dynamic_cast<VarBlockCreator::Ref *>(_var))
1251  return VarCodeGeneration::codegen(varBlockRef, varName, Builder);
1252  else
1253  return VarCodeGeneration::codegen(_var, varName, Builder);
1254  } else if (_localVar) {
1255  ExprType varTy = _localVar->type();
1256  if (varTy.isFP() || varTy.isString()) {
1257  // LLVM_VALUE valPtr = resolveLocalVar(name(), Builder);
1258  LLVM_VALUE varPtr = _localVar->varPtr();
1259  assert(varPtr && "can not found symbol?");
1260  return CREATE_LOAD(Builder, varPtr);
1261  }
1262  }
1263 
1264  assert(false);
1265  return nullptr;
1266 }
1267 
1268 LLVM_VALUE ExprVecNode::codegen(LLVM_BUILDER Builder) const
1269 {
1270  std::vector<LLVM_VALUE> elems;
1271  ConstantInt *zero = ConstantInt::get(Type::getInt32Ty(Builder.getContext()), 0);
1272  for (int i = 0; i < numChildren(); i++) {
1273  LLVM_VALUE val = child(i)->codegen(Builder);
1274  elems.push_back(val->getType()->isVectorTy() ? Builder.CreateExtractElement(val, zero) : val);
1275  }
1276  return createVecVal(Builder, elems);
1277 }
1278 } // namespace KSeExpr
1279 
1280 #endif
void KSeExprLLVMEvalFPVarRef(KSeExpr::ExprVarRef *seVR, double *result)
void KSeExprLLVMEvalStrVarRef(KSeExpr::ExprVarRef *seVR, double *result)
double LLVM_BUILDER
Definition: ExprLLVM.h:26
double LLVM_VALUE
Definition: ExprLLVM.h:25
std::string unescapeString(const std::string &string)
Definition: StringUtils.h:13
Node that calls a function.
Definition: ExprNode.h:654
int promote(int i) const
Definition: ExprNode.h:758
ExprLocalVar join (merge) references. Remembers which variables are possible assigners to this.
Definition: ExprEnv.h:84
ExprLocalVar * _thenVar
Definition: ExprEnv.h:115
bool valid() const
Definition: ExprEnv.h:102
ExprLocalVar * _elseVar
Definition: ExprEnv.h:115
virtual LLVM_VALUE codegen(LLVM_BUILDER, const std::string &, LLVM_VALUE) LLVM_BASE
LLVM value that has been allocated.
virtual LLVM_VALUE varPtr()
LLVM value that has been pre-done.
Definition: ExprEnv.h:72
ExprType type() const
returns type of the variable
Definition: ExprEnv.h:52
virtual LLVM_VALUE codegen(LLVM_BUILDER) LLVM_BASE
int numChildren() const
Number of children.
Definition: ExprNode.h:108
const ExprType & type() const
The type of the node.
Definition: ExprNode.h:150
const ExprNode * child(size_t i) const
Get 0 indexed child.
Definition: ExprNode.h:114
Node that stores a string.
Definition: ExprNode.h:632
bool isLifetimeUniform() const
Definition: ExprType.h:234
Type type() const
Definition: ExprType.h:176
int dim() const
Definition: ExprType.h:180
bool isString() const
Definition: ExprType.h:210
bool isFP() const
Direct is predicate checks.
Definition: ExprType.h:190
Node that references a variable.
Definition: ExprNode.h:572
abstract class for implementing variable references
Definition: Expression.h:36
virtual ExprType type() const
returns (current) type
Definition: Expression.h:50
virtual void eval(double *result)=0
returns this variable's value by setting result
Internally implemented var ref used by SeExpr.
Definition: VarBlock.h:94
uint32_t stride() const
Definition: VarBlock.h:103
uint32_t offset() const
Definition: VarBlock.h:99
KSeExpr_DEFAULT double_t floor(double_t val)
Definition: Utils.cpp:168
double max(double x, double y)
Definition: ExprBuiltins.h:74
const ExprStrNode * isString(const ExprNode *testee)
Definition: ExprPatterns.h:36