C++与Lua交互实例 – 矩阵的加减乘除(版本二)
TIPS：关于使用矩阵的加减乘除来测试C++与Lua交互的说明，以及下文未讲述的知识点，可以阅读第一版：
https://blog.csdn.net/qq135595696/article/details/128960951
下面两种方式中，矩阵数据都来源于C++端；区别在于第一种方式在C++端比较并展示结果，第二种方式（userdata）在Lua端比较并展示结果。
下面在C++端引入第三方开源库Eigen，用来验证Lua端矩阵运算的结果是否正确，参考链接如下：
http://eigen.tuxfamily.org/index.php?title=3.4
https://blog.csdn.net/qq_41854911/article/details/119814660
https://blog.csdn.net/thlzzz/article/details/110451022
CppToLua1
CppToLua.cpp
#include <iostream>
#include <vector>
#include <assert.h>
#include <Dense>
#include "lua.hpp"
using std::cout;
using std::endl;
using std::cin;
// Debug helper: STACK_NUM(L) stores the current Lua stack height in gs_Top
// and prints it.
static int gs_Top = 0;
// Wrapped in do { } while (0) so `STACK_NUM(L);` is a single statement (safe
// inside an unbraced if/else). The final line carries no continuation
// backslash, so the following source line is never spliced into the macro.
#define STACK_NUM(L)                                          \
    do {                                                      \
        gs_Top = lua_gettop(L);                               \
        std::cout << "stack top:" << gs_Top << std::endl;     \
    } while (0)
// Matrix operations
// Which matrix operation to perform: add, subtract, multiply or divide
// (A / B meaning A * inverse(B)). NONE selects no operation.
enum class MATRIX_OPERATE { ADD, SUB, MUL, DIV, NONE };
// Lua script that provides MatrixAdd/MatrixSub/MatrixMul/MatrixDiv.
#define LUA_SCRIPT_PATH "matrix2.0.lua"
// Left and right operands handed to the Lua script and to the Eigen check.
static std::vector<std::vector<double>> gs_mat1, gs_mat2;
// Prints a row-major matrix to stdout: every element is preceded by a space,
// each row ends with a newline, and a "......" separator line follows.
// Each row is printed with its own length, so ragged input never indexes past
// the end of a shorter row. Always returns true.
static bool OutPrint(const std::vector<std::vector<double>>& data) {
    for (const std::vector<double>& row : data) {
        for (double value : row)
            std::cout << " " << value;
        std::cout << '\n';
    }
    std::cout << "......\n";
    return true;
}
// Clears the operand globals and executes the script at LUA_SCRIPT_PATH.
// If loading or running fails, the error message (left on the Lua stack)
// is printed and false is returned; otherwise returns true.
static bool Init(lua_State* L) {
    assert(NULL != L);
    gs_mat1.clear();
    gs_mat2.clear();
    const bool failed = luaL_dofile(L, LUA_SCRIPT_PATH);
    if (!failed) {
        return true;
    }
    printf("%s\n", lua_tostring(L, -1));
    return false;
}
// Pushes `data` onto the Lua stack as a new 1-based table of row tables:
// { {a11, a12, ...}, {a21, ...}, ... }. Each row is copied with its own
// length, so ragged input is reproduced instead of reading past the end of a
// shorter row. Net stack effect: +1 (the outer table). Always returns true.
static bool CreateLuaArr(lua_State* L, const std::vector<std::vector<double>>& data) {
    assert(NULL != L);
    // Pre-size both levels: the element counts are known up front.
    lua_createtable(L, static_cast<int>(data.size()), 0);
    lua_Integer rowIndex = 1;
    for (const std::vector<double>& row : data) {
        lua_createtable(L, static_cast<int>(row.size()), 0);
        lua_Integer colIndex = 1;
        for (double value : row) {
            lua_pushnumber(L, value);
            lua_rawseti(L, -2, colIndex++);
        }
        // Store the row table into the outer table (pops the row).
        lua_rawseti(L, -2, rowIndex++);
    }
    return true;
}
// Reads the Matrix object on top of the Lua stack (a table with fields
// tbData, nRow and nColumn, as the script's Matrix constructor sets them)
// into outData, row by row.
// Returns true on success. Returns false, with outData empty, if the top
// value is not a table, tbData is not a table, or a row is not a table.
// Missing or non-numeric entries are read as 0 (lua_tonumber).
// The Lua stack is left exactly as on entry: the Matrix object stays on top
// and remains the caller's to pop.
static bool GetLuaArr(lua_State* L, std::vector<std::vector<double>>& outData) {
    assert(NULL != L);
    outData.clear();
    if (LUA_TTABLE != lua_type(L, -1)) {
        return false;
    }
    // The Matrix object is on top, so its absolute index is the stack height.
    const int top = lua_gettop(L);
    if (LUA_TTABLE != lua_getfield(L, top, "tbData")) {
        lua_settop(L, top);  // drop the non-table tbData value
        return false;
    }
    lua_getfield(L, top, "nRow");
    const lua_Integer row = lua_tointeger(L, -1);
    lua_pop(L, 1);
    lua_getfield(L, top, "nColumn");
    const lua_Integer col = lua_tointeger(L, -1);
    lua_pop(L, 1);

    bool ok = true;
    for (lua_Integer i = 1; ok && i <= row; i++) {
        // lua_rawgeti on a non-table is invalid, so check every row's type.
        if (LUA_TTABLE != lua_rawgeti(L, -1, i)) {
            ok = false;
        } else {
            outData.emplace_back();
            std::vector<double>& rowData = outData.back();
            for (lua_Integer j = 1; j <= col; j++) {
                lua_rawgeti(L, -1, j);
                rowData.push_back(lua_tonumber(L, -1));
                lua_pop(L, 1);
            }
        }
        lua_pop(L, 1);  // the row (or the non-table value in its place)
    }
    // Keep the Lua stack balanced: drop tbData, keep the Matrix object.
    lua_settop(L, top);
    if (!ok) {
        outData.clear();
    }
    return ok;
}
// Calls the Lua global that implements `type` (MatrixAdd, MatrixSub,
// MatrixMul or MatrixDiv) with gs_mat1 and gs_mat2 as arguments, and reads
// the returned Matrix object into outData.
// Returns false (outData empty) if `type` is NONE, the global is not a
// function, the call raises an error (its message is printed), or the
// result is not a readable matrix (e.g. nil from a singular division).
// The Lua stack is restored to its entry height on every path, so repeated
// calls do not accumulate results or error messages.
static bool MatrixOperate(lua_State* L,
    std::vector<std::vector<double>>& outData, MATRIX_OPERATE type) {
    assert(NULL != L);
    outData.clear();
    const char* funcName = NULL;
    switch (type) {
    case MATRIX_OPERATE::ADD:
        funcName = "MatrixAdd";
        break;
    case MATRIX_OPERATE::SUB:
        funcName = "MatrixSub";
        break;
    case MATRIX_OPERATE::MUL:
        funcName = "MatrixMul";
        break;
    case MATRIX_OPERATE::DIV:
        funcName = "MatrixDiv";
        break;
    case MATRIX_OPERATE::NONE:
    default:
        break;
    }
    // lua_getglobal(L, NULL) is invalid; nothing to call for NONE.
    if (NULL == funcName) {
        return false;
    }
    const int top = lua_gettop(L);
    // Test the type instead of luaL_checktype: a Lua error raised here is
    // outside any protected call and would end in the panic handler/abort.
    if (LUA_TFUNCTION != lua_getglobal(L, funcName)) {
        printf("error[%s is not a function]\n", funcName);
        lua_settop(L, top);
        return false;
    }
    // Arguments.
    CreateLuaArr(L, gs_mat1);
    CreateLuaArr(L, gs_mat2);
    bool result = false;
    if (lua_pcall(L, 2, 1, 0) != LUA_OK) {
        const char* message = lua_tostring(L, -1);
        printf("error[%s]\n", message ? message : "(error object is not a string)");
    } else {
        result = GetLuaArr(L, outData);
    }
    // Drop the returned matrix or the error message.
    lua_settop(L, top);
    return result;
}
static bool APIMatrixOperate(const std::vector<std::vector<double>>& data1,
const std::vector<std::vector<double>>& data2, MATRIX_OPERATE type, Eigen::MatrixXd& outResMat) {
Eigen::MatrixXd mat1(data1.size(), data1[0].size());
Eigen::MatrixXd mat2(data2.size(), data2[0].size());
for (int i = 0; i < data1.size(); i++) {
for (int j = 0; j < data1[0].size(); j++) {
mat1(i, j) = data1[i][j];
}
}
for (int i = 0; i < data2.size(); i++) {
for (int j = 0; j < data2[0].size(); j++) {
mat2(i, j) = data2[i][j];
}
}
switch (type) {
case MATRIX_OPERATE::ADD:
outResMat = mat1 + mat2;
break;
case MATRIX_OPERATE::SUB:
outResMat = mat1 - mat2;
break;
case MATRIX_OPERATE::MUL:
outResMat = mat1 * mat2;
break;
case MATRIX_OPERATE::DIV:
outResMat = mat1 * (mat2.inverse());
break;
case MATRIX_OPERATE::NONE:
break;
default:
break;
}
return true;
}
// Evaluates add, sub, mul and div both through the Lua script and through
// Eigen, then prints every Lua result followed by the Eigen reference.
// The subtraction's left operand is the Lua addition result, so those two
// steps are chained.
static bool Run(lua_State* L) {
    assert(NULL != L);
    typedef std::vector<std::vector<double>> Table;

    // Publishes the operands through gs_mat1/gs_mat2, then computes `op`
    // with the Lua script (into luaOut) and with Eigen (into eigenOut).
    auto evaluate = [L](const Table& lhs, const Table& rhs, MATRIX_OPERATE op,
                        Table& luaOut, Eigen::MatrixXd& eigenOut) {
        gs_mat1 = lhs;
        gs_mat2 = rhs;
        MatrixOperate(L, luaOut, op);
        APIMatrixOperate(gs_mat1, gs_mat2, op, eigenOut);
    };
    // Prints one section: banner, Lua result, Eigen reference.
    auto report = [](const char* banner, const Table& luaResult,
                     const Eigen::MatrixXd& reference) {
        cout << banner << endl;
        OutPrint(luaResult);
        cout << "正确答案:\n" << reference << endl;
    };

    Table addData, subData, mulData, divData;
    Eigen::MatrixXd addApiData, subApiData, mulApiData, divApiData;

    evaluate({ { 1,2,3 }, { 4,5,6 } }, { { 2,3,4 }, { 5,6,7 } },
             MATRIX_OPERATE::ADD, addData, addApiData);
    evaluate(addData, { {1,1,1},{1,1,1} },
             MATRIX_OPERATE::SUB, subData, subApiData);
    evaluate({ {1,2,3},{4,5,6} }, { {7,8},{9,10},{11,12} },
             MATRIX_OPERATE::MUL, mulData, mulApiData);
    evaluate({ {41,2,3},{424,5,6},{742,8,11} }, { {1,2,1},{1,1,2},{2,1,1} },
             MATRIX_OPERATE::DIV, divData, divApiData);

    report("================加法:================", addData, addApiData);
    report("================减法:================", subData, subApiData);
    report("================乘法:================", mulData, mulApiData);
    report("================除法:================", divData, divApiData);
    return true;
}
// Teardown hook: nothing to release here, so it always reports success.
static bool UnInit() { return true; }
// Demo driver: creates a Lua state with the standard libraries, loads the
// script, runs the comparisons, then closes the state.
// Returns 1 if the Lua state cannot be created, otherwise 0.
int main020811() {
    lua_State* L = luaL_newstate();
    // luaL_newstate returns NULL when the state cannot be allocated.
    if (NULL == L) {
        std::cout << "cannot create Lua state" << std::endl;
        return 1;
    }
    luaL_openlibs(L);
    if (Init(L)) {
        Run(L);
    }
    UnInit();
    lua_close(L);
    return 0;
}
matrix2.0.lua
-- Maps each class table to its method table ("vtbl").
local _class = {}

-- Runs the constructor chain base-first: every ancestor's Ctor (if set) is
-- applied to tbObj before the class's own Ctor, all with the same arguments.
local function _RunCtors(tbClass, tbObj, ...)
    if tbClass.super then
        _RunCtors(tbClass.super, tbObj, ...)
    end
    if tbClass.Ctor then
        tbClass.Ctor(tbObj, ...)
    end
end

--- Creates a class table, optionally deriving from `super`.
-- Ctor, super and New live directly on the class table; every other field
-- assigned to it (i.e. methods) is redirected into the class's vtbl.
-- Instances created by New() look methods up in that vtbl.
function class(super)
    local tbClassType = { Ctor = false, super = super }

    tbClassType.New = function(...)
        local tbObj = {}
        _RunCtors(tbClassType, tbObj, ...)
        local vtbl = _class[tbClassType]
        -- A Ctor may have installed its own metatable (e.g. for operator
        -- metamethods); reuse it and only add the method lookup.
        local tbMeta = getmetatable(tbObj)
        if tbMeta then
            tbMeta.__index = vtbl
        else
            setmetatable(tbObj, { __index = vtbl })
        end
        return tbObj
    end

    local vtbl = {}
    _class[tbClassType] = vtbl
    setmetatable(tbClassType, {
        __newindex = function(_, k, v)
            vtbl[k] = v
        end,
    })

    if super then
        -- Inherited members are fetched from the parent's vtbl on first use
        -- and cached in this class's vtbl.
        setmetatable(vtbl, {
            __index = function(_, k)
                local inherited = _class[super][k]
                vtbl[k] = inherited
                return inherited
            end,
        })
    end

    return tbClassType
end
Matrix = class()

-- Element-wise kernel shared by __add and __sub: applies fnOp to every pair
-- of corresponding entries. Shapes must match exactly; otherwise a warning
-- is printed and an empty Matrix is returned.
local function _MatrixElementWise(tbSource, tbDest, fnOp)
    assert(tbSource,"tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nRow ~= tbDest.nRow
        or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbRes
    end
    for nRowIdx, tbRow in ipairs(tbSource.tbData) do
        local tbDestRow = tbDest.tbData[nRowIdx]
        local tbResRow = {}
        for nColIdx, nValue in ipairs(tbRow) do
            tbResRow[nColIdx] = fnOp(nValue, tbDestRow[nColIdx])
        end
        tbRes.tbData[nRowIdx] = tbResRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

local function _Plus(a, b) return a + b end
local function _Minus(a, b) return a - b end

--- Constructor, run by Matrix.New(data).
-- @param data table of row tables (kept by reference); nil is treated as an
--   empty matrix.
-- Sets nRow = #data and nColumn = #data[1] (0 when there are no rows), and
-- installs the + - * / metamethods.
function Matrix:Ctor(data)
    data = data or {}
    self.tbData = data
    self.nRow = #data
    if self.nRow > 0 then
        self.nColumn = #data[1]
    else
        self.nColumn = 0
    end
    -- A fresh metatable per instance: class.New afterwards stores this
    -- class's method table in it as __index, so sharing one metatable across
    -- classes would let a subclass overwrite it for every instance.
    setmetatable(self, {
        __add = function(tbSource, tbDest)
            return _MatrixElementWise(tbSource, tbDest, _Plus)
        end,
        __sub = function(tbSource, tbDest)
            return _MatrixElementWise(tbSource, tbDest, _Minus)
        end,
        __mul = function(tbSource, tbDest)
            return self:_MartixMul(tbSource, tbDest)
        end,
        -- A / B is computed as A * adj(B) / det(B). Returns nil (after a
        -- warning) when B is not square or is singular.
        __div = function(tbSource, tbDest)
            assert(tbSource,"tbSource not exist")
            assert(tbDest, "tbDest not exist")
            -- Determinant and adjugate are only defined for square matrices.
            if tbDest.nRow ~= tbDest.nColumn or tbDest.nRow == 0 then
                print("divisor is not a non-empty square matrix...")
                return nil
            end
            local nDet = self:_GetDetValue(tbDest)
            if nDet == 0 then
                print("matrix no inverse matrix...")
                return nil
            end
            local tbInverseDest = self:_MatrixNumMul(self:_GetCompanyMatrix(tbDest), 1 / nDet)
            return self:_MartixMul(tbSource, tbInverseDest)
        end,
    })
end
--- Writes the matrix to stdout, one row per line, every value followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end
-- Method forms of the arithmetic operators; each returns exactly what the
-- corresponding metamethod returns.
-- Add
Matrix.Add = function(self, matrix) return self + matrix end
-- Subtract
Matrix.Sub = function(self, matrix) return self - matrix end
-- Multiply
Matrix.Mul = function(self, matrix) return self * matrix end
-- Divide
Matrix.Div = function(self, matrix) return self / matrix end
--- Returns a new (nRow - 1) x (nColumn - 1) Matrix equal to tbMatrix with
-- row rowIndex and column colIndex removed (the minor used for cofactors).
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix,"tbMatrix not exist")
    assert(rowIndex >= 1,"rowIndex < 1")
    assert(colIndex >= 1,"colIndex < 1")
    local nOutRows = tbMatrix.nRow - 1
    local nOutCols = tbMatrix.nColumn - 1
    local tbMinor = Matrix.New({})
    tbMinor.nRow = nOutRows
    tbMinor.nColumn = nOutCols
    -- With no columns left, no row tables are created at all.
    if nOutCols < 1 then
        return tbMinor
    end
    for i = 1, nOutRows do
        -- Rows and columns at or after the removed index shift by one.
        local tbSrcRow = tbMatrix.tbData[(i < rowIndex) and i or (i + 1)]
        local tbOutRow = {}
        for j = 1, nOutCols do
            tbOutRow[j] = tbSrcRow[(j < colIndex) and j or (j + 1)]
        end
        tbMinor.tbData[i] = tbOutRow
    end
    return tbMinor
end
--- Determinant of a square matrix by Laplace (cofactor) expansion along the
-- first row.
-- A 0x0 matrix has determinant 1 (the empty product). This is what makes
-- the adjugate of a 1x1 matrix {{a}} come out as {{1}}, so 1x1 division works.
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- A 1x1 matrix's determinant is its only element.
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end
    local nAns = 0
    for i = 1, tbMatrix.nColumn do
        -- Cofactor sign along row 1: + for odd columns, - for even ones.
        local nFlag = (i % 2 ~= 0) and 1 or -1
        nAns =
            nAns + tbMatrix.tbData[1][i] *
            self:_GetDetValue(self:_CutoffMatrix(tbMatrix, 1, i)) * nFlag
    end
    return nAns
end
--- Adjugate (classical adjoint) of tbMatrix: the transpose of its cofactor
-- matrix, so entry [j][i] holds the signed minor of entry (i, j).
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    local tbAdj = Matrix.New({})
    -- Transposed shape relative to tbMatrix.
    tbAdj.nRow, tbAdj.nColumn = tbMatrix.nColumn, tbMatrix.nRow
    local tbOut = tbAdj.tbData
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nSign = ((i + j) % 2 == 0) and 1 or -1
            local tbOutRow = tbOut[j]
            if not tbOutRow then
                tbOutRow = {}
                tbOut[j] = tbOutRow
            end
            tbOutRow[i] = nSign * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, i, j))
        end
    end
    return tbAdj
end
--- Scalar multiplication: scales every entry of tbMatrix by num IN PLACE and
-- returns the same (mutated) matrix.
function Matrix:_MatrixNumMul(tbMatrix, num)
    local nCols = tbMatrix.nColumn
    for i = 1, tbMatrix.nRow do
        local tbRow = tbMatrix.tbData[i]
        for j = 1, nCols do
            tbRow[j] = tbRow[j] * num
        end
    end
    return tbMatrix
end
--- Matrix product tbSource * tbDest, returned as a new Matrix.
-- Requires tbSource.nColumn == tbDest.nRow; otherwise prints a warning and
-- returns an empty Matrix (the same "no result" value __add/__sub return).
-- It never returns an operand itself, so callers may mutate the result
-- (e.g. _MatrixNumMul scales in place) without aliasing an input.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource,"tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        return tbRes
    end
    local tbDestData = tbDest.tbData
    for i = 1, tbSource.nRow do
        local tbSrcRow = tbSource.tbData[i]
        local tbResRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + (tbSrcRow[k] * tbDestData[k][j])
            end
            tbResRow[j] = nSum
        end
        tbRes.tbData[i] = tbResRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end
-- Shared body of the global entry points called from C++: wraps both raw
-- row tables as Matrix objects and applies fnOp to them.
local function _MatrixBinary(data1, data2, fnOp)
    assert(data1,"data1 not exist")
    assert(data2,"data2 not exist")
    local lhs = Matrix.New(data1)
    local rhs = Matrix.New(data2)
    return fnOp(lhs, rhs)
end

-- add
function MatrixAdd(data1, data2)
    return _MatrixBinary(data1, data2, function(a, b) return a + b end)
end

-- sub
function MatrixSub(data1, data2)
    return _MatrixBinary(data1, data2, function(a, b) return a - b end)
end

-- mul
function MatrixMul(data1, data2)
    return _MatrixBinary(data1, data2, function(a, b) return a * b end)
end

-- div (returns nil when the divisor has no inverse)
function MatrixDiv(data1, data2)
    return _MatrixBinary(data1, data2, function(a, b) return a / b end)
end
输出结果
文章来源:https://uudwc.com/A/1awG
CppToLua2
CppToLua.cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <vector>
#include <Dense>
#include "lua.hpp"
using std::cout;
using std::endl;
using std::cin;
// Registry key of the metatable shared by all C++ matrix userdata.
#define CPP_MATRIX          "CPP_MATRIX"
// Lua script executed by Init().
#define LUA_SCRIPT_PATH     "matrix2.0-lua.lua"
// Debug helper: STACK_NUM(L) stores the current Lua stack height in gs_Top
// and prints it.
static int gs_Top = 0;
// Wrapped in do { } while (0) so `STACK_NUM(L);` is a single statement (safe
// inside an unbraced if/else). The final line carries no continuation
// backslash, so the following source line is never spliced into the macro.
#define STACK_NUM(L)                                          \
    do {                                                      \
        gs_Top = lua_gettop(L);                               \
        std::cout << "stack top:" << gs_Top << std::endl;     \
    } while (0)
// Matrix operations
// Which matrix operation to perform: add, subtract, multiply or divide
// (A / B meaning A * inverse(B)). NONE selects no operation.
enum class MATRIX_OPERATE { ADD, SUB, MUL, DIV, NONE };
// Left and right operands passed to the Lua Matrix* functions.
static std::vector<std::vector<double>> gs_mat1, gs_mat2;
extern "C" {
// Lua: CreateMatrix() -> CPP_MATRIX userdata.
// The userdata holds only a pointer to a heap-allocated, empty Eigen matrix;
// the CPP_MATRIX metatable's __gc (UnInitMatrix) deletes it.
static int CreateMatrix(lua_State* L) {
    void* block = lua_newuserdata(L, sizeof(Eigen::MatrixXd*));
    Eigen::MatrixXd** handle = static_cast<Eigen::MatrixXd**>(block);
    *handle = new Eigen::MatrixXd();
    luaL_setmetatable(L, CPP_MATRIX);
    return 1;
}
// Lua: matrix:InitMatrix(rows)
// Resizes the matrix to #rows x #rows[1] and copies the values from the
// 1-based table of row tables. Missing or non-numeric entries become 0.
// Raises a Lua error if `rows` is not a table, a row is not a table, or the
// handle no longer owns a matrix. An empty table yields a 0x0 matrix.
static int InitMatrix(lua_State* L) {
    assert(NULL != L);
    Eigen::MatrixXd** pp =
        (Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
    luaL_argcheck(L, NULL != *pp, 1, "matrix already released");
    // Argument 1 is self, argument 2 the data table; absolute indices stay
    // correct even if extra arguments are passed.
    luaL_checktype(L, 2, LUA_TTABLE);
    const lua_Integer row = luaL_len(L, 2);
    lua_Integer col = 0;
    // Only inspect the first row when there is one (#nil would raise).
    if (row > 0) {
        const int firstType = lua_rawgeti(L, 2, 1);
        luaL_argcheck(L, LUA_TTABLE == firstType, 2, "every row must be a table");
        col = luaL_len(L, -1);
        lua_pop(L, 1);
    }
    (*pp)->resize(row, col);
    for (lua_Integer i = 0; i < row; i++) {
        const int rowType = lua_rawgeti(L, 2, i + 1);
        luaL_argcheck(L, LUA_TTABLE == rowType, 2, "every row must be a table");
        for (lua_Integer j = 0; j < col; j++) {
            lua_rawgeti(L, -1, j + 1);
            (**pp)(i, j) = lua_tonumber(L, -1);
            lua_pop(L, 1);
        }
        lua_pop(L, 1);
    }
    return 0;
}
// __gc metamethod for CPP_MATRIX userdata: deletes the owned Eigen matrix.
// The pointer is reset afterwards because the metatable is also the __index
// table, so Lua code can call m:__gc() explicitly; without the reset the
// collector's own call would delete the matrix a second time.
static int UnInitMatrix(lua_State* L) {
    Eigen::MatrixXd** pp =
        (Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
    std::cout << "auto gc" << std::endl;
    delete *pp;  // deleting NULL is a no-op
    *pp = NULL;
    return 0;
}
static int AddMatrix(lua_State* L) {
//STACK_NUM(L);
Eigen::MatrixXd** pp1 =
(Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =
(Eigen::MatrixXd**)luaL_checkudata(L, 2, CPP_MATRIX);
Eigen::MatrixXd** pp =
(Eigen::MatrixXd**)lua_newuserdata(L, sizeof(Eigen::MatrixXd*));
*pp = new Eigen::MatrixXd(); //该部分内存由C++分配
**pp = (**pp1) + (**pp2);
luaL_setmetatable(L, CPP_MATRIX);
//STACK_NUM(L);
return 1;
}
static int SubMatrix(lua_State* L) {
//STACK_NUM(L);
Eigen::MatrixXd** pp1 =
(Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =
(Eigen::MatrixXd**)luaL_checkudata(L, 2, CPP_MATRIX);
Eigen::MatrixXd** pp =
(Eigen::MatrixXd**)lua_newuserdata(L, sizeof(Eigen::MatrixXd*));
*pp = new Eigen::MatrixXd(); //该部分内存由C++分配
**pp = (**pp1) - (**pp2);
luaL_setmetatable(L, CPP_MATRIX);
//STACK_NUM(L);
return 1;
}
static int MulMatrix(lua_State* L) {
//STACK_NUM(L);
Eigen::MatrixXd** pp1 =
(Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =
(Eigen::MatrixXd**)luaL_checkudata(L, 2, CPP_MATRIX);
Eigen::MatrixXd** pp =
(Eigen::MatrixXd**)lua_newuserdata(L, sizeof(Eigen::MatrixXd*));
*pp = new Eigen::MatrixXd(); //该部分内存由C++分配
**pp = (**pp1) * (**pp2);
luaL_setmetatable(L, CPP_MATRIX);
//STACK_NUM(L);
return 1;
}
static int DivMatrix(lua_State* L) {
//STACK_NUM(L);
Eigen::MatrixXd** pp1 =
(Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =
(Eigen::MatrixXd**)luaL_checkudata(L, 2, CPP_MATRIX);
Eigen::MatrixXd** pp =
(Eigen::MatrixXd**)lua_newuserdata(L, sizeof(Eigen::MatrixXd*));
*pp = new Eigen::MatrixXd(); //该部分内存由C++分配
**pp = (**pp1) * ((**pp2).inverse());
luaL_setmetatable(L, CPP_MATRIX);
//STACK_NUM(L);
return 1;
}
// Lua: matrix:PrintMatrix() -- prints the matrix below a "correct answer"
// heading. Raises a Lua argument error for a NULL handle instead of
// dereferencing it.
static int PrintMatrix(lua_State* L) {
    Eigen::MatrixXd** pp =
        (Eigen::MatrixXd**)luaL_checkudata(L, 1, CPP_MATRIX);
    luaL_argcheck(L, NULL != *pp, 1, "matrix already released");
    std::cout << "正确答案:\n" << **pp << std::endl;
    return 0;
}
}
// Functions installed into the CPP_MATRIX metatable. The metatable is also
// its own __index, so the plain names are callable as methods on matrix
// userdata (m:InitMatrix(t), m:PrintMatrix()).
static const luaL_Reg MatrixFuncs[] = {
    // methods
    { "InitMatrix",  InitMatrix   },
    { "PrintMatrix", PrintMatrix  },
    // lifetime
    { "__gc",        UnInitMatrix },
    // arithmetic
    { "__add",       AddMatrix    },
    { "__sub",       SubMatrix    },
    { "__mul",       MulMatrix    },
    { "__div",       DivMatrix    },
    // sentinel
    { NULL,          NULL         }
};
extern "C" {
// Creates (or fetches, if already registered) the registry metatable
// CPP_MATRIX, makes it its own __index so methods resolve on the userdata,
// and installs MatrixFuncs. Leaves the Lua stack unchanged; returns true.
static bool CreateMatrixMetaTable(lua_State* L) {
    luaL_newmetatable(L, CPP_MATRIX);
    const int meta = lua_gettop(L);
    lua_pushvalue(L, meta);
    lua_setfield(L, meta, "__index");
    luaL_setfuncs(L, MatrixFuncs, 0);
    lua_settop(L, meta - 1);
    return true;
}
}
// Pushes `data` onto the Lua stack as a new 1-based table of row tables:
// { {a11, a12, ...}, {a21, ...}, ... }. Each row is copied with its own
// length, so ragged input is reproduced instead of reading past the end of a
// shorter row. Net stack effect: +1 (the outer table). Always returns true.
bool CreateLuaArr(lua_State* L, const std::vector<std::vector<double>>& data) {
    assert(NULL != L);
    // Pre-size both levels: the element counts are known up front.
    lua_createtable(L, static_cast<int>(data.size()), 0);
    lua_Integer rowIndex = 1;
    for (const std::vector<double>& row : data) {
        lua_createtable(L, static_cast<int>(row.size()), 0);
        lua_Integer colIndex = 1;
        for (double value : row) {
            lua_pushnumber(L, value);
            lua_rawseti(L, -2, colIndex++);
        }
        // Store the row table into the outer table (pops the row).
        lua_rawseti(L, -2, rowIndex++);
    }
    return true;
}
// Calls the Lua global that implements `type` (MatrixAdd, MatrixSub,
// MatrixMul or MatrixDiv) with gs_mat1 and gs_mat2 as arguments; those
// script functions print their results themselves.
// Returns false if `type` is NONE, the global is not a function, or the call
// raises an error (its message is printed). The Lua stack is restored to its
// entry height on every path.
bool MatrixOperate(lua_State* L, MATRIX_OPERATE type) {
    assert(NULL != L);
    const char* funcName = NULL;
    switch (type) {
    case MATRIX_OPERATE::ADD:
        funcName = "MatrixAdd";
        break;
    case MATRIX_OPERATE::SUB:
        funcName = "MatrixSub";
        break;
    case MATRIX_OPERATE::MUL:
        funcName = "MatrixMul";
        break;
    case MATRIX_OPERATE::DIV:
        funcName = "MatrixDiv";
        break;
    case MATRIX_OPERATE::NONE:
    default:
        break;
    }
    // lua_getglobal(L, NULL) is invalid; nothing to call for NONE.
    if (NULL == funcName) {
        return false;
    }
    const int top = lua_gettop(L);
    // Test the type instead of luaL_checktype: a Lua error raised here is
    // outside any protected call and would end in the panic handler/abort.
    if (LUA_TFUNCTION != lua_getglobal(L, funcName)) {
        printf("error[%s is not a function]\n", funcName);
        lua_settop(L, top);
        return false;
    }
    // Arguments.
    CreateLuaArr(L, gs_mat1);
    CreateLuaArr(L, gs_mat2);
    const bool failed = lua_pcall(L, 2, 0, 0) != LUA_OK;
    if (failed) {
        const char* message = lua_tostring(L, -1);
        printf("error[%s]\n", message ? message : "(error object is not a string)");
    }
    // Drop the error message, if any.
    lua_settop(L, top);
    return !failed;
}
// Registers the CPP_MATRIX metatable and the global CreateMatrix() factory,
// then runs LUA_SCRIPT_PATH. Returns false (after printing and popping the
// error message) if the script cannot be loaded or raises an error.
bool Init(lua_State *L) {
    assert(NULL != L);
    // Build the global metatable named CPP_MATRIX.
    CreateMatrixMetaTable(L);
    // Expose the C++ constructor to the script as the global CreateMatrix.
    lua_pushcfunction(L, CreateMatrix);
    lua_setglobal(L, "CreateMatrix");
    if (luaL_dofile(L, LUA_SCRIPT_PATH)) {
        const char* message = lua_tostring(L, -1);
        printf("%s\n", message ? message : "(error object is not a string)");
        lua_pop(L, 1);
        // Without the script the Matrix* globals do not exist, so the caller
        // must not go on to run the demo.
        return false;
    }
    return true;
}
// Runs the four demo cases. For each one the operands are published through
// gs_mat1/gs_mat2 and the matching Lua function is invoked; the script
// prints both its own result and the C++ reference.
bool Run(lua_State* L) {
    assert(NULL != L);
    typedef std::vector<std::vector<double>> Table;
    struct Case {
        Table lhs;
        Table rhs;
        MATRIX_OPERATE op;
    };
    const Case cases[] = {
        { { { 1,2,3 }, { 4,5,6 } }, { { 2,3,4 }, { 5,6,7 } }, MATRIX_OPERATE::ADD },
        { { { 1,2,3 }, { 4,5,6 } }, { { 1,1,1 }, { 1,1,1 } }, MATRIX_OPERATE::SUB },
        { { {1,2,3},{4,5,6} }, { {7,8},{9,10},{11,12} }, MATRIX_OPERATE::MUL },
        { { {41,2,3},{424,5,6},{742,8,11} }, { {1,2,1},{1,1,2},{2,1,1} }, MATRIX_OPERATE::DIV },
    };
    for (const Case& current : cases) {
        gs_mat1 = current.lhs;
        gs_mat2 = current.rhs;
        MatrixOperate(L, current.op);
    }
    return true;
}
// Teardown hook: nothing to release here, so it always reports success.
bool UnInit() { return true; }
// Program entry: creates a Lua state with the standard libraries, sets up
// the bindings and the script, runs the demo, then closes the state.
// Returns 1 if the Lua state cannot be created, otherwise 0.
int main() {
    lua_State* L = luaL_newstate();
    // luaL_newstate returns NULL when the state cannot be allocated.
    if (NULL == L) {
        std::cout << "cannot create Lua state" << std::endl;
        return 1;
    }
    luaL_openlibs(L);
    if (Init(L)) {
        Run(L);
    }
    UnInit();
    lua_close(L);
    return 0;
}
matrix2.0-lua.lua
-- Maps each class table to its method table ("vtbl").
local _class = {}

-- Runs the constructor chain base-first: every ancestor's Ctor (if set) is
-- applied to tbObj before the class's own Ctor, all with the same arguments.
local function _RunCtors(tbClass, tbObj, ...)
    if tbClass.super then
        _RunCtors(tbClass.super, tbObj, ...)
    end
    if tbClass.Ctor then
        tbClass.Ctor(tbObj, ...)
    end
end

--- Creates a class table, optionally deriving from `super`.
-- Ctor, super and New live directly on the class table; every other field
-- assigned to it (i.e. methods) is redirected into the class's vtbl.
-- Instances created by New() look methods up in that vtbl.
function class(super)
    local tbClassType = { Ctor = false, super = super }

    tbClassType.New = function(...)
        local tbObj = {}
        _RunCtors(tbClassType, tbObj, ...)
        local vtbl = _class[tbClassType]
        -- A Ctor may have installed its own metatable (e.g. for operator
        -- metamethods); reuse it and only add the method lookup.
        local tbMeta = getmetatable(tbObj)
        if tbMeta then
            tbMeta.__index = vtbl
        else
            setmetatable(tbObj, { __index = vtbl })
        end
        return tbObj
    end

    local vtbl = {}
    _class[tbClassType] = vtbl
    setmetatable(tbClassType, {
        __newindex = function(_, k, v)
            vtbl[k] = v
        end,
    })

    if super then
        -- Inherited members are fetched from the parent's vtbl on first use
        -- and cached in this class's vtbl.
        setmetatable(vtbl, {
            __index = function(_, k)
                local inherited = _class[super][k]
                vtbl[k] = inherited
                return inherited
            end,
        })
    end

    return tbClassType
end
Matrix = class()

-- Element-wise kernel shared by __add and __sub: applies fnOp to every pair
-- of corresponding entries. Shapes must match exactly; otherwise a warning
-- is printed and an empty Matrix is returned.
local function _MatrixElementWise(tbSource, tbDest, fnOp)
    assert(tbSource,"tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nRow ~= tbDest.nRow
        or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbRes
    end
    for nRowIdx, tbRow in ipairs(tbSource.tbData) do
        local tbDestRow = tbDest.tbData[nRowIdx]
        local tbResRow = {}
        for nColIdx, nValue in ipairs(tbRow) do
            tbResRow[nColIdx] = fnOp(nValue, tbDestRow[nColIdx])
        end
        tbRes.tbData[nRowIdx] = tbResRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

local function _Plus(a, b) return a + b end
local function _Minus(a, b) return a - b end

--- Constructor, run by Matrix.New(data).
-- @param data table of row tables (kept by reference); nil is treated as an
--   empty matrix.
-- Sets nRow = #data and nColumn = #data[1] (0 when there are no rows), and
-- installs the + - * / metamethods.
function Matrix:Ctor(data)
    data = data or {}
    self.tbData = data
    self.nRow = #data
    if self.nRow > 0 then
        self.nColumn = #data[1]
    else
        self.nColumn = 0
    end
    -- A fresh metatable per instance: class.New afterwards stores this
    -- class's method table in it as __index, so sharing one metatable across
    -- classes would let a subclass overwrite it for every instance.
    setmetatable(self, {
        __add = function(tbSource, tbDest)
            return _MatrixElementWise(tbSource, tbDest, _Plus)
        end,
        __sub = function(tbSource, tbDest)
            return _MatrixElementWise(tbSource, tbDest, _Minus)
        end,
        __mul = function(tbSource, tbDest)
            return self:_MartixMul(tbSource, tbDest)
        end,
        -- A / B is computed as A * adj(B) / det(B). Returns nil (after a
        -- warning) when B is not square or is singular.
        __div = function(tbSource, tbDest)
            assert(tbSource,"tbSource not exist")
            assert(tbDest, "tbDest not exist")
            -- Determinant and adjugate are only defined for square matrices.
            if tbDest.nRow ~= tbDest.nColumn or tbDest.nRow == 0 then
                print("divisor is not a non-empty square matrix...")
                return nil
            end
            local nDet = self:_GetDetValue(tbDest)
            if nDet == 0 then
                print("matrix no inverse matrix...")
                return nil
            end
            local tbInverseDest = self:_MatrixNumMul(self:_GetCompanyMatrix(tbDest), 1 / nDet)
            return self:_MartixMul(tbSource, tbInverseDest)
        end,
    })
end
--- Writes the matrix to stdout, one row per line, every value followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end
-- Method forms of the arithmetic operators; each returns exactly what the
-- corresponding metamethod returns.
-- Add
Matrix.Add = function(self, matrix) return self + matrix end
-- Subtract
Matrix.Sub = function(self, matrix) return self - matrix end
-- Multiply
Matrix.Mul = function(self, matrix) return self * matrix end
-- Divide
Matrix.Div = function(self, matrix) return self / matrix end
--- Returns a new (nRow - 1) x (nColumn - 1) Matrix equal to tbMatrix with
-- row rowIndex and column colIndex removed (the minor used for cofactors).
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix,"tbMatrix not exist")
    assert(rowIndex >= 1,"rowIndex < 1")
    assert(colIndex >= 1,"colIndex < 1")
    local nOutRows = tbMatrix.nRow - 1
    local nOutCols = tbMatrix.nColumn - 1
    local tbMinor = Matrix.New({})
    tbMinor.nRow = nOutRows
    tbMinor.nColumn = nOutCols
    -- With no columns left, no row tables are created at all.
    if nOutCols < 1 then
        return tbMinor
    end
    for i = 1, nOutRows do
        -- Rows and columns at or after the removed index shift by one.
        local tbSrcRow = tbMatrix.tbData[(i < rowIndex) and i or (i + 1)]
        local tbOutRow = {}
        for j = 1, nOutCols do
            tbOutRow[j] = tbSrcRow[(j < colIndex) and j or (j + 1)]
        end
        tbMinor.tbData[i] = tbOutRow
    end
    return tbMinor
end
--- Determinant of a square matrix by Laplace (cofactor) expansion along the
-- first row.
-- A 0x0 matrix has determinant 1 (the empty product). This is what makes
-- the adjugate of a 1x1 matrix {{a}} come out as {{1}}, so 1x1 division works.
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- A 1x1 matrix's determinant is its only element.
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end
    local nAns = 0
    for i = 1, tbMatrix.nColumn do
        -- Cofactor sign along row 1: + for odd columns, - for even ones.
        local nFlag = (i % 2 ~= 0) and 1 or -1
        nAns =
            nAns + tbMatrix.tbData[1][i] *
            self:_GetDetValue(self:_CutoffMatrix(tbMatrix, 1, i)) * nFlag
    end
    return nAns
end
--- Adjugate (classical adjoint) of tbMatrix: the transpose of its cofactor
-- matrix, so entry [j][i] holds the signed minor of entry (i, j).
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    local tbAdj = Matrix.New({})
    -- Transposed shape relative to tbMatrix.
    tbAdj.nRow, tbAdj.nColumn = tbMatrix.nColumn, tbMatrix.nRow
    local tbOut = tbAdj.tbData
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nSign = ((i + j) % 2 == 0) and 1 or -1
            local tbOutRow = tbOut[j]
            if not tbOutRow then
                tbOutRow = {}
                tbOut[j] = tbOutRow
            end
            tbOutRow[i] = nSign * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, i, j))
        end
    end
    return tbAdj
end
--- Scalar multiplication: scales every entry of tbMatrix by num IN PLACE and
-- returns the same (mutated) matrix.
function Matrix:_MatrixNumMul(tbMatrix, num)
    local nCols = tbMatrix.nColumn
    for i = 1, tbMatrix.nRow do
        local tbRow = tbMatrix.tbData[i]
        for j = 1, nCols do
            tbRow[j] = tbRow[j] * num
        end
    end
    return tbMatrix
end
--- Matrix product tbSource * tbDest, returned as a new Matrix.
-- Requires tbSource.nColumn == tbDest.nRow; otherwise prints a warning and
-- returns an empty Matrix (the same "no result" value __add/__sub return).
-- It never returns an operand itself, so callers may mutate the result
-- (e.g. _MatrixNumMul scales in place) without aliasing an input.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource,"tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        return tbRes
    end
    local tbDestData = tbDest.tbData
    for i = 1, tbSource.nRow do
        local tbSrcRow = tbSource.tbData[i]
        local tbResRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + (tbSrcRow[k] * tbDestData[k][j])
            end
            tbResRow[j] = nSum
        end
        tbRes.tbData[i] = tbResRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end
-- Shared body of the global entry points called from C++. It computes one
-- operation on the pure-Lua Matrix and on the C++ (Eigen) userdata, then
-- prints szTitle, the Lua result and the C++ result.
-- fnOp applies the operator (+, -, * or /) to two operands of the same kind.
local function _CompareWithCpp(szTitle, data1, data2, fnOp)
    assert(data1,"data1 not exist")
    assert(data2,"data2 not exist")
    local matrix1 = Matrix.New(data1)
    local matrix2 = Matrix.New(data2)
    local matrix3 = fnOp(matrix1, matrix2)
    print(szTitle)
    -- __div returns nil when the divisor has no inverse; don't index nil.
    if matrix3 then
        matrix3:Print()
    else
        print("no lua result")
    end
    local cppMatrix1 = CreateMatrix()
    cppMatrix1:InitMatrix(data1)
    local cppMatrix2 = CreateMatrix()
    cppMatrix2:InitMatrix(data2)
    local cppMatrix3 = fnOp(cppMatrix1, cppMatrix2)
    cppMatrix3:PrintMatrix()
end

-- add
function MatrixAdd(data1, data2)
    _CompareWithCpp("===========加法===========", data1, data2,
        function(a, b) return a + b end)
end

-- sub
function MatrixSub(data1, data2)
    _CompareWithCpp("===========减法===========", data1, data2,
        function(a, b) return a - b end)
end

-- mul
function MatrixMul(data1, data2)
    _CompareWithCpp("===========乘法===========", data1, data2,
        function(a, b) return a * b end)
end

-- div
function MatrixDiv(data1, data2)
    _CompareWithCpp("===========除法===========", data1, data2,
        function(a, b) return a / b end)
end
输出结果
文章来源地址https://uudwc.com/A/1awG