CSV格式
概念
CSV(逗号分隔值)文件是一种特殊的文件类型,可在 Excel 中创建或编辑。CSV文件采用逗号分隔的形式来存储文本和数字信息,总体来说,这种形式的文件格式具有扩展性好,移植性强的特点。
目前许多主流程序采用CSV文件作为数据导入导出的中间格式,例如MySQL数据库可以从CSV文件中导入数据,GMail联系人可以导出到CSV文件,然后将其导入到Outlook中。
pytorch和KuiperInfer中间数据对比文件,用于读取Pytorch复杂,且多变维度的输出,来对比我们的推理结果。
从pytorch输出(被定位到csv文件)中读取,KuiperInfer读取,然后再对比.
过程是这样的:
pytorch读取权重文件--》infer--》numpy--》再保存到本地,格式为csv -->读取并初始化 tensor
kuiper 读取权重文件-->保存为tensor -->
绿色背景就是我们今天要去做的。
代码:
其实这个也不会很难:
std::shared_ptr<Tensor<float >> CSVDataLoader::LoadData(const std::string &file_path, char split_char) {
CHECK(!file_path.empty()) << "File path is empty!";
std::ifstream in(file_path);
CHECK(in.is_open() && in.good()) << "File open failed! " << file_path;
std::string line_str;
std::stringstream line_stream;
const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char);
std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows, cols);
arma::fmat &data = input_tensor->at(0);
size_t row = 0;
while (in.good()) {
std::getline(in, line_str);
if (line_str.empty()) {
break;
}
std::string token;
line_stream.clear();
line_stream.str(line_str);
size_t col = 0;
while (line_stream.good()) {
std::getline(line_stream, token, split_char);
try {
data.at(row, col) = std::stof(token);
}
catch (std::exception &e) {
LOG(ERROR) << "Parse CSV File meet error: " << e.what();
continue;
}
col += 1;
CHECK(col <= cols) << "There are excessive elements on the column";
}
row += 1;
CHECK(row <= rows) << "There are excessive elements on the row";
}
return input_tensor;
}
核心区段:
const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char); std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows, cols); arma::fmat &data = input_tensor->at(0);
确认rows和cols,创建空tensor
while (line_stream.good()) { std::getline(line_stream, token, split_char); try { data.at(row, col) = std::stof(token); } catch (std::exception &e) { LOG(ERROR) << "Parse CSV File meet error: " << e.what(); continue; } col += 1; CHECK(col <= cols) << "There are excessive elements on the column"; } row += 1; CHECK(row <= rows) << "There are excessive elements on the row"; } return input_tensor;
迭代,读进来赋值,返回Tensor
这里唯一需要注意的就是有表头的情况:文章来源:https://uudwc.com/A/od48E
try {
//todo 补充
// 能够读取到第一行的csv列名,并存放在headers中
// 能够读取到第二行之后的csv数据,并相应放置在data变量的row,col位置中
if (row == 0)
{
headers.push_back(token);
}
else{
data.at(row-1, col) = std::stof(token);
}
}
在try中要先将表头传到header里,再将其余的数组读到data中,但记住data中的row要-1,因为要扣除表头的一行。文章来源地址https://uudwc.com/A/od48E