8 #include "ExprConfig.h"
19 #if defined(SEEXPR_ENABLE_LLVM)
// LLVMEvaluationContext<T>: holds the two JIT-compiled entry points produced by
// prepLLVM (single evaluation and ranged/looped evaluation) and a heap buffer of
// `dim` T values (allocated in init) that receives single-evaluation results.
// T is double for FP expressions and char* for string expressions (see the
// _llvmEvalFP / _llvmEvalStr members). Copying is disabled; moving is defaulted.
// NOTE(review): the defaulted move operations copy the raw owning pointer
// `resultData` without nulling it in the source object. If the destructor (its
// body is elided in this listing) delete[]s resultData, both the moved-from and
// the moved-to object will free the same buffer. Consider std::unique_ptr<T[]>
// or hand-written moves that std::exchange the pointer.
28 template<
class T>
class LLVMEvaluationContext
// Single evaluation: (output buffer, data block, indirect index).
31 using FunctionPtr = void (*)(T *,
char **, uint32_t);
// Ranged evaluation: (data block, output slot offset, range start, range end).
32 using FunctionPtrMultiple = void (*)(
char **, uint32_t, uint32_t, uint32_t);
33 FunctionPtr functionPtr{
nullptr};
34 FunctionPtrMultiple functionPtrMultiple{
nullptr};
// Owned array of `dim` elements allocated with new[] in init().
35 T *resultData{
nullptr};
38 LLVMEvaluationContext(
const LLVMEvaluationContext &) =
delete;
39 LLVMEvaluationContext &operator=(
const LLVMEvaluationContext &) =
delete;
40 LLVMEvaluationContext(LLVMEvaluationContext &&) noexcept = default;
41 LLVMEvaluationContext& operator=(LLVMEvaluationContext &&) noexcept = default;
42 ~LLVMEvaluationContext()
46 LLVMEvaluationContext() =
default;
// Binds the JIT entry points and allocates a result buffer of `dim` elements.
// fp / fpLoop are the addresses returned by ExecutionEngine::getPointerToFunction.
// NOTE(review): the lines between the signature and these assignments are elided
// here; confirm any previously allocated resultData is released before the new[]
// below, otherwise re-initialisation leaks the old buffer.
48 void init(
void *fp,
void *fpLoop,
int dim)
51 functionPtr =
reinterpret_cast<FunctionPtr
>(fp);
52 functionPtrMultiple =
reinterpret_cast<FunctionPtrMultiple
>(fpLoop);
53 resultData =
new T[dim];
// Part of a cleanup path whose surrounding lines are elided in this listing.
// NOTE(review): in the visible code only functionPtr is cleared here, not
// functionPtrMultiple.
59 functionPtr =
nullptr;
// Single evaluation: runs the compiled function, writing into resultData.
// A null varBlock passes a null data block and indirect index 0.
// Returns const T* (the return statement is elided; presumably resultData).
62 const T *operator()(VarBlock *varBlock)
64 assert(functionPtr && resultData);
65 functionPtr(resultData, varBlock ? varBlock->data() :
nullptr, varBlock ? varBlock->indirectIndex : 0);
// Ranged evaluation over [rangeStart, rangeEnd): the compiled loop writes its
// results through the data block slot at outputVarBlockOffset, not resultData.
// NOTE(review): the assert checks functionPtr, but the call below goes through
// functionPtrMultiple, which is never checked. Also, the size_t arguments are
// implicitly narrowed to the uint32_t parameters of FunctionPtrMultiple.
68 void operator()(VarBlock *varBlock,
size_t outputVarBlockOffset,
size_t rangeStart,
size_t rangeEnd)
70 assert(functionPtr && resultData);
71 functionPtrMultiple(varBlock ? varBlock->data() :
nullptr, outputVarBlockOffset, rangeStart, rangeEnd);
// Evaluation contexts for FP and string results; created by prepLLVM.
74 std::unique_ptr<LLVMEvaluationContext<double>> _llvmEvalFP;
75 std::unique_ptr<LLVMEvaluationContext<char *>> _llvmEvalStr;
// Declaration order matters: members are destroyed in reverse order, so
// TheExecutionEngine (which takes ownership of the generated Module in
// prepLLVM) is destroyed before _llvmContext, the context that Module lives in.
77 std::unique_ptr<llvm::LLVMContext> _llvmContext;
78 std::unique_ptr<llvm::ExecutionEngine> TheExecutionEngine;
// A default-constructed evaluator holds no compiled code. prepLLVM must run
// before evalFP / evalStr / evalMultiple, because those dereference the
// evaluation-context pointers, which are null until then.
81 LLVMEvaluator() =
default;
// Evaluates a string-typed expression once and returns the first char* stored
// in the string context's result buffer.
// NOTE(review): dereferences _llvmEvalStr without checking that prepLLVM created
// it; calling this on an evaluator prepared for FP output is undefined behavior.
83 const char *
evalStr(VarBlock *varBlock)
85 return *(*_llvmEvalStr)(varBlock);
// Evaluates an FP expression once and returns the pointer produced by the FP
// context. The buffer is owned by that context and is rewritten by the next call.
// NOTE(review): _llvmEvalFP is dereferenced unchecked; prepLLVM must have set it.
87 const double *
evalFP(VarBlock *varBlock)
89 return (*_llvmEvalFP)(varBlock);
// Runs the compiled loop function for every index in [rangeStart, rangeEnd).
// Each result is written into the array referenced by the data block slot at
// outputVarBlockOffset. This always goes through the FP context's loop pointer.
92 void evalMultiple(VarBlock *varBlock, uint32_t outputVarBlockOffset, uint32_t rangeStart, uint32_t rangeEnd)
94 return (*_llvmEvalFP)(varBlock, outputVarBlockOffset, rangeStart, rangeEnd);
// prepLLVM: JIT-compiles `parseTree` with an llvm::ExecutionEngine and binds the
// resulting native code to the evaluation contexts. Two functions are generated:
//   * "<name>_func"(output, double** dataBlock, i32 indirectIndex) evaluates the
//     expression once and writes dimDesired values to `output` (double* for FP,
//     char** for strings).
//   * "<name>_loopfunc"(i8* dataBlock, i32 outputVarBlockOffset, i32 rangeStart,
//     i32 rangeEnd) calls "<name>_func" for every index in [rangeStart, rangeEnd).
// Parameters:
//   parseTree         - expression tree root; codegen() and type() are called on it
//                       (presumably already type-checked; confirm with callers).
//   desiredReturnType - isFP() selects FP vs string output; dim() is the output width.
// Returns: bool. The return statements are elided in this listing.
// NOTE(review): several lines of this function (closing braces, else-branches,
// pass-manager runs) are elided here; the comments below describe only the
// visible code.
102 bool prepLLVM(ExprNode *parseTree,
const ExprType &desiredReturnType)
104 using namespace llvm;
105 InitializeNativeTarget();
106 InitializeNativeTargetAsmPrinter();
107 InitializeNativeTargetAsmParser();
// Per-evaluator symbol prefix (derived from this object's address).
109 std::string uniqueName = getUniqueName();
// NOTE(review): if prepLLVM runs a second time on the same evaluator, this
// destroys the previous LLVMContext while TheExecutionEngine still owns a Module
// created in it (the engine is only replaced by the reset() further down). Reset
// TheExecutionEngine before replacing the context.
112 _llvmContext = std::make_unique<LLVMContext>();
114 std::unique_ptr<Module> TheModule(
new Module(uniqueName +
"_module", *_llvmContext));
// Commonly used types. This uses the typed-pointer API (getInt8PtrTy,
// getDoublePtrTy, ...), which newer LLVM releases removed in favour of opaque
// pointers. TODO confirm the supported LLVM version range.
117 Type *i8PtrTy = Type::getInt8PtrTy(*_llvmContext);
118 PointerType *i8PtrPtrTy = PointerType::getUnqual(i8PtrTy);
119 PointerType *i8PtrPtrPtrTy = PointerType::getUnqual(i8PtrPtrTy);
120 Type *i32Ty = Type::getInt32Ty(*_llvmContext);
121 Type *i32PtrTy = Type::getInt32PtrTy(*_llvmContext);
122 Type *i64Ty = Type::getInt64Ty(*_llvmContext);
123 Type *doublePtrTy = Type::getDoublePtrTy(*_llvmContext);
124 PointerType *doublePtrPtrTy = PointerType::getUnqual(doublePtrTy);
125 Type *voidTy = Type::getVoidTy(*_llvmContext);
// External helpers declared in the module. They are bound to host addresses
// with addGlobalMapping once the engine exists.
128 Function *KSeExprLLVMEvalCustomFunctionFunc =
nullptr;
129 Function *KSeExprLLVMEvalFPVarRefFunc =
nullptr;
130 Function *KSeExprLLVMEvalStrVarRefFunc =
nullptr;
131 Function *KSeExprLLVMEvalstrlenFunc =
nullptr;
132 Function *KSeExprLLVMEvalmallocFunc =
nullptr;
133 Function *KSeExprLLVMEvalfreeFunc =
nullptr;
134 Function *KSeExprLLVMEvalmemsetFunc =
nullptr;
135 Function *KSeExprLLVMEvalstrcatFunc =
nullptr;
136 Function *KSeExprLLVMEvalstrcmpFunc =
nullptr;
139 FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty},
false);
140 KSeExprLLVMEvalCustomFunctionFunc = Function::Create(FT, GlobalValue::ExternalLinkage,
"KSeExprLLVMEvalCustomFunction", TheModule.get());
143 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, doublePtrTy},
false);
144 KSeExprLLVMEvalFPVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage,
"KSeExprLLVMEvalFPVarRef", TheModule.get());
147 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy},
false);
148 KSeExprLLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage,
"KSeExprLLVMEvalStrVarRef", TheModule.get());
// NOTE(review): C strlen returns size_t, but it is declared here as returning
// i32. Where size_t is 64-bit, only the low 32 bits of the result are used.
151 FunctionType *FT = FunctionType::get(i32Ty, {i8PtrTy},
false);
152 KSeExprLLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage,
"strlen", TheModule.get());
// NOTE(review): C malloc takes size_t, but it is declared here with an i32
// argument. On ABIs that pass it in a 64-bit register, the upper bits may be
// unspecified. TODO confirm against the target ABI.
155 FunctionType *FT = FunctionType::get(i8PtrTy, {i32Ty},
false);
156 KSeExprLLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage,
"malloc", TheModule.get());
159 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy},
false);
160 KSeExprLLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage,
"free", TheModule.get());
// NOTE(review): C memset is void*(void*, int, size_t); this declaration is
// void(i8*, i32, i32), so both the return type and the length width differ.
163 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty},
false);
164 KSeExprLLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage,
"memset", TheModule.get());
167 FunctionType *FT = FunctionType::get(i8PtrTy, {i8PtrTy, i8PtrTy},
false);
168 KSeExprLLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage,
"strcat", TheModule.get());
171 FunctionType *FT = FunctionType::get(i32Ty, {i8PtrTy, i8PtrTy},
false);
172 KSeExprLLVMEvalstrcmpFunc = Function::Create(FT, Function::ExternalLinkage,
"strcmp", TheModule.get());
// Entry function: the output pointer is double* for FP results and char** for
// string results.
177 bool desireFP = desiredReturnType.isFP();
178 std::array<Type *, 3> ParamTys = {desireFP ? doublePtrTy : i8PtrPtrTy, doublePtrPtrTy, i32Ty};
179 FunctionType *FT = FunctionType::get(voidTy, ParamTys,
false);
180 Function *F = Function::Create(FT, Function::ExternalLinkage, uniqueName +
"_func", TheModule.get());
// Mark the entry function AlwaysInline so it can be inlined into the loop
// function that calls it. The #if picks AttributeList vs the older AttributeSet
// spelling.
181 #if LLVM_VERSION_MAJOR > 4
182 F->addAttribute(llvm::AttributeList::FunctionIndex, llvm::Attribute::AlwaysInline);
184 F->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::AlwaysInline);
188 std::array<const char *, 3> names = {
"outputPointer",
"dataBlock",
"indirectIndex"};
190 for (
auto &arg : F->args())
191 arg.setName(names[idx++]);
194 auto dimDesired = desiredReturnType.dim();
195 auto dimGenerated = parseTree->type().dim();
197 BasicBlock *BB = BasicBlock::Create(*_llvmContext,
"entry", F);
198 IRBuilder<> Builder(BB);
// Emit IR for the whole expression; lastVal is its result value.
201 Value *lastVal = parseTree->codegen(Builder);
// First argument = output pointer.
204 Value *firstArg = &*F->arg_begin();
// Widen the result to dimDesired. If it is a vector, copy it element-wise into
// the output.
206 Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
207 if (newLastVal->getType()->isVectorTy()) {
210 assert(dimDesired >= 1 &&
"error. dim of FP is less than 1.");
212 assert(dimGenerated >= 1 &&
"error. dim of FP is less than 1.");
// NOTE(review): && binds tighter than ||, so the message only attaches to the
// second operand. The semantics are unchanged (the literal is always true), but
// parenthesize to silence -Wparentheses.
214 assert(dimGenerated == 1 || dimGenerated >= dimDesired &&
"error: unable to match between FP of differing dimensions");
// NOTE(review): llvm::cast asserts on a type mismatch and never returns null, so
// the `VT &&` checks below are redundant. Use dyn_cast if a null check is intended.
216 auto *VT = llvm::cast<llvm::VectorType>(newLastVal->getType());
217 #if LLVM_VERSION_MAJOR >= 13
218 if (VT && VT->getElementCount().getKnownMinValue() >= dimDesired) {
220 if (VT && VT->getNumElements() >= dimDesired) {
// Enough lanes: output[i] = newLastVal[i].
222 for (
unsigned i = 0; i < dimDesired; ++i) {
223 Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
224 Value *val = Builder.CreateExtractElement(newLastVal, idx);
225 Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
226 Builder.CreateStore(val, ptr);
// Fewer lanes than requested: broadcast lane 0 into every output slot.
229 for (
unsigned i = 0; i < dimDesired; ++i) {
230 Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
231 Value *original_idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), 0);
232 Value *val = Builder.CreateExtractElement(newLastVal, original_idx);
233 Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
234 Builder.CreateStore(val, ptr);
// Non-vector path (the enclosing else is elided in this listing).
// NOTE(review): if control reaches here because the promoted value was not a
// vector, the llvm::cast to VectorType below will assert. This branch looks
// unreachable; confirm.
238 if (dimGenerated > 1) {
239 Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
241 auto *VT = llvm::cast<llvm::VectorType>(newLastVal->getType());
242 #if LLVM_VERSION_MAJOR >= 13
243 assert(VT && VT->getElementCount().getKnownMinValue() >= dimDesired);
245 assert(VT && VT->getNumElements() >= dimDesired);
248 for (
unsigned i = 0; i < dimDesired; ++i) {
249 Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
250 Value *val = Builder.CreateExtractElement(newLastVal, idx);
251 Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
252 Builder.CreateStore(val, ptr);
254 }
else if (dimGenerated == 1) {
// Scalar result: store the same value into each of the dimDesired output slots.
// NOTE(review): a nullptr element type relies on typed-pointer inference; newer
// LLVM requires an explicit element type here.
255 for (
unsigned i = 0; i < dimDesired; ++i) {
256 Value *ptr = Builder.CreateConstInBoundsGEP1_32(
nullptr, firstArg, i);
257 Builder.CreateStore(lastVal, ptr);
260 assert(
false &&
"error. dim of FP is less than 1.");
// Presumably the string (non-FP) branch (condition elided): store the single
// result value through the output pointer.
264 Builder.CreateStore(lastVal, firstArg);
267 Builder.CreateRetVoid();
// Loop function: evaluates the entry function over an index range and writes
// each result into the output array referenced by the data block.
271 FunctionType *FTLOOP = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty, i32Ty},
false);
272 Function *FLOOP = Function::Create(FTLOOP, Function::ExternalLinkage, uniqueName +
"_loopfunc", TheModule.get());
275 std::array<const char *, 4> names = {
"dataBlock",
"outputVarBlockOffset",
"rangeStart",
"rangeEnd"};
277 for (
auto &arg : FLOOP->args()) {
278 arg.setName(names[idx++]);
// Output stride per index (dimDesired values) and the increment constant.
283 Value *dimValue = ConstantInt::get(i32Ty, dimDesired);
284 Value *oneValue = ConstantInt::get(i32Ty, 1);
287 BasicBlock *entryBlock = BasicBlock::Create(*_llvmContext,
"entry", FLOOP);
288 BasicBlock *loopCmpBlock = BasicBlock::Create(*_llvmContext,
"loopCmp", FLOOP);
289 BasicBlock *loopRepeatBlock = BasicBlock::Create(*_llvmContext,
"loopRepeat", FLOOP);
290 BasicBlock *loopIncBlock = BasicBlock::Create(*_llvmContext,
"loopInc", FLOOP);
291 BasicBlock *loopEndBlock = BasicBlock::Create(*_llvmContext,
"loopEnd", FLOOP);
292 IRBuilder<> Builder(entryBlock);
293 Builder.SetInsertPoint(entryBlock);
// Fetch arguments in order (the iterator increments are elided in this listing).
296 Function::arg_iterator argIterator = FLOOP->arg_begin();
297 Value *varBlockCharPtrPtrArg = &*argIterator;
299 Value *outputVarBlockOffsetArg = &*argIterator;
301 Value *rangeStartArg = &*argIterator;
303 Value *rangeEndArg = &*argIterator;
// Stack slots for the loop state and for two typed views of the data block.
307 Value *rangeStartVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue,
"rangeStartVar");
308 Value *rangeEndVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue,
"rangeEndVar");
309 Value *indexVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue,
"indexVar");
310 Value *outputVarBlockOffsetVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue,
"outputVarBlockOffsetVar");
311 Value *varBlockDoublePtrPtrVar = Builder.CreateAlloca(doublePtrPtrTy, oneValue,
"varBlockDoublePtrPtrVar");
312 Value *varBlockTPtrPtrVar = Builder.CreateAlloca(desireFP ==
true ? doublePtrPtrTy : i8PtrPtrPtrTy, oneValue,
"varBlockTPtrPtrVar");
315 Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, doublePtrPtrTy,
"varBlockAsDoublePtrPtr"), varBlockDoublePtrPtrVar);
316 Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, desireFP ? doublePtrPtrTy : i8PtrPtrPtrTy,
"varBlockAsTPtrPtr"), varBlockTPtrPtrVar);
317 Builder.CreateStore(rangeStartArg, rangeStartVar);
318 Builder.CreateStore(rangeEndArg, rangeEndVar);
// NOTE(review): outputVarBlockOffsetVar is stored here but never read in the
// visible code; the GEP below uses the argument directly.
319 Builder.CreateStore(outputVarBlockOffsetArg, outputVarBlockOffsetVar);
// outputBasePtr = dataBlock[outputVarBlockOffset], i.e. the output array is
// referenced from a slot of the data block.
322 Value *outputBasePtrPtr = Builder.CreateGEP(
nullptr, CREATE_LOAD(Builder, varBlockTPtrPtrVar), outputVarBlockOffsetArg,
"outputBasePtrPtr");
323 Value *outputBasePtr = CREATE_LOAD_WITH_ID(Builder, outputBasePtrPtr,
"outputBasePtr");
324 Builder.CreateStore(CREATE_LOAD(Builder, rangeStartVar), indexVar);
// for (index = rangeStart; index < rangeEnd /* unsigned */; ++index)
//     F(outputBasePtr + dimDesired * index, dataBlock, index);
326 Builder.CreateBr(loopCmpBlock);
327 Builder.SetInsertPoint(loopCmpBlock);
328 Value *cond = Builder.CreateICmpULT(CREATE_LOAD(Builder, indexVar), CREATE_LOAD(Builder, rangeEndVar));
329 Builder.CreateCondBr(cond, loopRepeatBlock, loopEndBlock);
331 Builder.SetInsertPoint(loopRepeatBlock);
332 Value *myOutputPtr = Builder.CreateGEP(
nullptr, outputBasePtr, Builder.CreateMul(dimValue, CREATE_LOAD(Builder, indexVar)));
333 Builder.CreateCall(F, {myOutputPtr, CREATE_LOAD(Builder, varBlockDoublePtrPtrVar), CREATE_LOAD(Builder, indexVar)});
335 Builder.CreateBr(loopIncBlock);
337 Builder.SetInsertPoint(loopIncBlock);
338 Builder.CreateStore(Builder.CreateAdd(CREATE_LOAD(Builder, indexVar), oneValue), indexVar);
339 Builder.CreateBr(loopCmpBlock);
341 Builder.SetInsertPoint(loopEndBlock);
342 Builder.CreateRetVoid();
// Debug dump of the unverified IR (its guard, presumably the debugging flag, is elided).
347 std::cerr <<
"Pre verified LLVM byte code " << std::endl;
348 TheModule->print(llvm::errs(),
nullptr);
// Keep a raw handle: TheModule is moved into the EngineBuilder below.
357 Module *altModule = TheModule.get();
359 TheExecutionEngine.reset(EngineBuilder(std::move(TheModule))
360 .setErrorStr(&ErrStr)
362 .setOptLevel(CodeGenOpt::Aggressive)
// NOTE(review): TheExecutionEngine is dereferenced here and in every
// addGlobalMapping call below, but `if (!TheExecutionEngine)` is only checked
// further down. If engine creation fails, this is a null-pointer dereference.
// Move that check to directly after the reset() above.
365 altModule->setDataLayout(TheExecutionEngine->getDataLayout());
// Bind the external declarations to host functions. Casting a function pointer
// to void* is conditionally-supported in ISO C++ but works on mainstream platforms.
368 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalFPVarRefFunc,
370 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalStrVarRefFunc,
372 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalCustomFunctionFunc,
374 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrlenFunc,
375 reinterpret_cast<void *
>(strlen));
376 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrcatFunc,
377 reinterpret_cast<void *
>(strcat));
378 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrcmpFunc,
379 reinterpret_cast<void *
>(strcmp));
380 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalmemsetFunc,
381 reinterpret_cast<void *
>(memset));
382 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalmallocFunc,
383 reinterpret_cast<void *
>(malloc));
384 TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalfreeFunc,
385 reinterpret_cast<void *
>(free));
// llvm::verifyModule returns true when the module is broken; details go to errorStr.
388 std::string errorStr;
389 llvm::raw_string_ostream raw(errorStr);
390 if (llvm::verifyModule(*altModule, &raw)) {
// Optimization via the legacy PassManagerBuilder (at -O3 with an always-inliner).
// The calls that run pm/fpm are elided in this listing.
// TODO confirm the target LLVM version still ships the legacy PassManagerBuilder.
396 llvm::PassManagerBuilder builder;
397 std::unique_ptr<llvm::legacy::PassManager> pm(
new llvm::legacy::PassManager);
398 std::unique_ptr<llvm::legacy::FunctionPassManager> fpm(
new llvm::legacy::FunctionPassManager(altModule));
399 builder.OptLevel = 3;
400 #if (LLVM_VERSION_MAJOR >= 4)
401 builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
403 builder.Inliner = llvm::createAlwaysInlinerPass();
405 builder.populateModulePassManager(*pm);
407 builder.populateFunctionPassManager(*fpm);
// Too late: the engine has already been dereferenced above (see the earlier note).
414 if (!TheExecutionEngine) {
415 std::cerr <<
"Could not create ExecutionEngine: " << ErrStr << std::endl;
// Generate machine code and fetch native entry points for both functions.
419 TheExecutionEngine->finalizeObject();
420 void *fp = TheExecutionEngine->getPointerToFunction(F);
421 void *fpLoop = TheExecutionEngine->getPointerToFunction(FLOOP);
// Presumably only one of these two contexts is created, depending on desireFP
// (the branching is elided in this listing).
423 _llvmEvalFP = std::make_unique<LLVMEvaluationContext<double>>();
424 _llvmEvalFP->init(fp, fpLoop, dimDesired);
426 _llvmEvalStr = std::make_unique<LLVMEvaluationContext<char *>>();
427 _llvmEvalStr->init(fp, fpLoop, dimDesired);
// NOTE(review): this dump happens after verification and optimization, yet it
// reuses the "Pre verified" label.
432 std::cerr <<
"Pre verified LLVM byte code " << std::endl;
433 altModule->print(llvm::errs(),
nullptr);
// Builds a per-evaluator symbol prefix: "_" followed by this object's address in
// hexadecimal. prepLLVM uses it to name the generated module and functions.
// NOTE(review): the name is unique only among concurrently live evaluators,
// since an address can be reused after an evaluator is destroyed.
440 std::string getUniqueName()
const
442 std::ostringstream o;
443 o << std::setbase(16) << reinterpret_cast<uintptr_t>(
this);
444 return (
"_" + o.str());
// Fallback for builds compiled without LLVM support.
// NOTE(review): assert() is compiled out under NDEBUG, so in release builds this
// reports nothing and execution continues. Consider an explicit error path.
454 assert(
false &&
"LLVM is not enabled in build");
void KSeExprLLVMEvalFPVarRef(KSeExpr::ExprVarRef *seVR, double *result)
void KSeExprLLVMEvalStrVarRef(KSeExpr::ExprVarRef *seVR, double *result)
void KSeExprLLVMEvalCustomFunction(int *opDataArg, double *fpArg, char **strArg, void **funcdata, const KSeExpr::ExprFuncNode *node)
Node that calls a function.
abstract class for implementing variable references
static bool debugging
Whether to debug expressions.
static bool prepLLVM(ExprNode *, ExprType)
static const char * evalStr(VarBlock *)
static const double * evalFP(VarBlock *)
static void evalMultiple(VarBlock *, int, size_t, size_t)
static void unsupported()
A thread local evaluation context. Just allocate and fill in with data.
@ Unknown
Unknown error (message = %1)