// // Expr.hpp // MNN // // Created by MNN on 2019/06/10. // Copyright © 2018, Alibaba Group Holding Limited // #ifndef Expr_hpp #define Expr_hpp #include #include #include #include #include #include #include #include namespace MNN { struct OpT; struct Op; struct NetT; namespace Express { class Variable; class Expr; class Executor; typedef std::shared_ptr EXPRP; typedef std::weak_ptr WeakEXPRP; typedef std::vector INTS; enum Dimensionformat { NHWC, NC4HW4, NCHW }; class MNN_PUBLIC VARP { public: VARP() { // Do nothing } VARP(std::shared_ptr c) { mContent = std::move(c); } VARP(Variable* c) { mContent.reset(c); } Variable* get() const { return mContent.get(); } ~ VARP() { // Do nothing } VARP(const VARP& var) { mContent = var.mContent; } VARP(VARP&& var) { mContent = std::move(var.mContent); } VARP operator+(VARP var) const; VARP operator-(VARP var) const; VARP operator*(VARP var) const; VARP operator/(VARP var) const; VARP mean(INTS dims) const; VARP sum(INTS dims) const; bool operator==(const VARP& var) const { return var.mContent == mContent; } bool operator<(const VARP& var) const { return mContent < var.mContent; } bool operator<=(const VARP& var) const { return mContent <= var.mContent; } VARP& operator=(const VARP& var) { mContent = var.mContent; return *this; } VARP& operator=(Variable* var) { mContent.reset(var); return *this; } Variable* operator->() const { return mContent.get(); } enum InputType { INPUT = 0, CONSTANT = 1, TRAINABLE = 2, }; bool fix(InputType type) const; private: friend class Variable; std::shared_ptr mContent; }; inline bool operator==(Variable* src, VARP dst) { return src == dst.get(); } inline bool operator!=(Variable* src, VARP dst) { return src != dst.get(); } // inline bool operator<(VARP src, VARP dst) { // return src.get() < dst.get(); // } typedef std::vector VARPS; class MNN_PUBLIC Variable { public: struct Info { Dimensionformat order = NHWC; INTS dim; halide_type_t type; int size; void syncSize(); }; const std::string& name() const; void setName(const std::string& name); std::pair expr() const { return std::make_pair(mFrom, mFromIndex); } // If compute info error, return nullptr const Info* getInfo(); bool resize(INTS dims); template const T* readMap() { return (const T*)readInternal(); } template T* writeMap() { return (T*)writeInternal(); } //Depecerate void unMap(); bool input(VARP src); static void replace(VARP dst, VARP src); static VARP create(EXPRP expr, int index = 0); static std::vector load(const char* fileName); static std::map loadMap(const char* fileName); static std::vector load(const uint8_t* buffer, size_t length); static std::map loadMap(const uint8_t* buffer, size_t length); static std::pair, std::map> getInputAndOutput(const std::map& allVariable); static std::vector mapToSequence(const std::map& source); static std::vector getExecuteOrder(const std::vector& output); static void save(const std::vector& vars, const char* fileName); static void save(const std::vector& vars, NetT* dest); // Pack a few Variable to compute in one pipeline static void prepareCompute(const std::vector& vars, bool forceCPU = false); size_t linkNumber() const; const std::vector& toExprs() const; void setExpr(EXPRP expr, int index) { mFrom = expr; mFromIndex = index; } private: Variable(EXPRP expr, int index) { mFrom = expr; mFromIndex = index; } void* readInternal(bool forShape = false); void* writeInternal(bool inform=true); void informDirty(); friend class Expr; EXPRP mFrom; int mFromIndex; }; class MNN_PUBLIC Expr { public: struct Inside; static EXPRP create(Variable::Info&& info, const void* ptr, VARP::InputType type, bool copy = true); static EXPRP create(const OpT* op, std::vector inputs, int outputSize = 1); static EXPRP create(std::pair, int> extra, std::vector&& inputs, int outputSize = 1); static EXPRP create(std::unique_ptr&& op, std::vector inputs, int outputSize = 1) { return create(op.get(), inputs, outputSize); } void setName(const std::string& name); const Op* get() const { return mOp; } const std::vector& inputs() const { return mInputs; } int outputSize() const { return (int)mOutputNames.size(); } static void replace(EXPRP oldExpr, EXPRP newExpr); bool requireInfo(); void visitOutputs(const std::function& visit); static void visit(EXPRP expr, const std::function& before, const std::function& after); const std::vector& outputs() const { return mTo; } ~Expr(); bool visited() const { return mVisited; } void setVisited(bool visited) { mVisited = visited; } const std::string& name() const { return mName; } const std::string& outputName(int index) { return mOutputNames[index]; } VARP::InputType inputType() const {return mType;} Variable::Info* outputInfo(int index) const; std::pair, int> extra() const { return std::make_pair(mExtraBuffer, mOpBufferSize); } bool setInfoDirty(); std::shared_ptr inside() const { return mInside; } bool valid() const { return mValid; } void setEntry(const std::vector& entries) { mEntries = entries; } const std::vector& getEntry() const { return mEntries; } private: static void _addLinkForInputs(EXPRP expr); Expr(int outputSize); friend class Variable; friend class VARP; VARP::InputType mType; const Op* mOp; std::vector mInputs; std::vector mOutputNames; bool mValid = true; std::shared_ptr mExtraBuffer; int mOpBufferSize = 0; std::string mName; std::shared_ptr mInside = nullptr; bool mVisited = false; std::vector mTo; // Only the enter input has entries, and it helps to get info for enter // input expression. std::vector mEntries; }; } // namespace Express } // namespace MNN #endif /* Expr_hpp */