Files
2024-12-13 23:33:37 +08:00

65 lines
1.6 KiB
C++

//
// Optimizer.hpp
// MNN
//
// Created by MNN on 2019/08/20.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef Optimizer_hpp
#define Optimizer_hpp
#include <MNN/expr/Expr.hpp>
#include <MNN/MNNForwardType.h>
namespace MNN {
namespace Express {
class MNN_PUBLIC Optimizer {
public:
enum Device {
CPU = 0,
GPU = 1,
OTHER = 2,
AUTO = 3
};
struct Config {
Device device = CPU;
MNNForwardType forwardType = MNN_FORWARD_ALL;
int numThread = 4;
};
static std::shared_ptr<Optimizer> create(Config config);
struct Cost {
float compute; // MFlops
float memory; // MB
};
class Parameters {
public:
Parameters(int n);
virtual ~Parameters();
float* get() const {
return mValue;
}
int size() const {
return mSize;
}
private:
float* mValue;
int mSize;
};
virtual std::shared_ptr<Parameters> onGetParameters(const std::vector<VARP>& outputs) {
return nullptr;
}
//Given paramters and measure cost, the parameters must be the same as onGetParameters
virtual Cost onMeasure(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) = 0;
//Modify the output directly, the parameters must be the same as onGetParameters
virtual bool onExecute(const std::vector<VARP>& outputs, std::shared_ptr<Parameters> parameters = nullptr) = 0;
Optimizer() = default;
virtual ~Optimizer() = default;
};
} // namespace Express
} // namespace MNN
#endif