Update tflite_support to v0.3.1
Quite a few patches needed, see the README for more details.
Sadly this does come with a binary size increase because of additional
files needed from tflite. This update is important for continued
code stability and security.
Bug: 1248206
Binary-Size: Size increase is unavoidable (see above).
Change-Id: I17d553da036617d51897dad3a370cc09e7cd9eba
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3345972
Reviewed-by: Dirk Pranke <dpranke@google.com>
Reviewed-by: Josh Simmons <jds@google.com>
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Reviewed-by: Michael Crouse <mcrouse@chromium.org>
Reviewed-by: Ravjit Uppal <ravjit@chromium.org>
Commit-Queue: Robert Ogden <robertogden@chromium.org>
Cr-Commit-Position: refs/heads/main@{#956539}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
e0582ecaa5
commit
f4cac4fd17
components
optimization_guide
permissions
prediction_service
segmentation_platform
internal
execution
translate
core
language_detection
infra
third_party
tflite
tflite_support
BUILD.gnREADME.chromium
patches
0001-Fix-signed-comparison-in-base_vision_task_api.h.patch0001-Remove-signed-comparison-in-frame_buffer.h.patch0001-Remove-unused-qualifiers-in-frame_buffer.h.patch0001-Use-third_party-libyuv.patch0001-add-metadata-name-check.patch0001-bert-max-seq-len.patch0001-no-absl-cord.patch0001-task-utils-sign-compare.patch0001-use-StringPiece-for-string_view.patch0001-use-SysNSStringToUTF8.patch0001-use-base-logging.patch0001-use-exit.patch0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch0001-use-size_t.patch0002-sentencepiece-tokenization-not-supported.patch0003-rm-unused-func.patch0004-rm-noop-deprecated-attribute.patch0005-use-size_t-in-for-loop.patch0006-unused-variable.patch0007-do-not-use-absl-any.patch0008-unused-string-include.patch0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch0010-only-support-model-file-passed-in-from-mem.patch0011-run-clang-format.patch
src
.bazelrc.bazelversionREADME.mdWORKSPACE
tensorflow_lite_support
BUILD
acceleration
README.md
configuration
c
BUILDcommon.cccommon.hcommon_utils.cccommon_utils.h
task
core
processor
BUILDbounding_box.hcategory.hclassification_options.hclassification_result.ccclassification_result.hdetection_result.h
text
BUILDbert_nl_classifier.ccbert_nl_classifier.hbert_question_answerer.ccbert_question_answerer.hnl_classifier.ccnl_classifier.hnl_classifier_common.ccnl_classifier_common.h
vision
test
task
cc
BUILDcommon.cccommon.h
port
task
README.md
audio
BUILDaudio_classifier.ccaudio_classifier.haudio_embedder.ccaudio_embedder.h
core
proto
BUILDaudio_classifier_options.protoaudio_embedder_options.protoclass_proto_inc.hclassifications_proto_inc.h
utils
core
BUILDbase_task_api.hcategory.hclassification_head.ccclassification_head.herror_reporter.ccerror_reporter.hexternal_file_handler.ccexternal_file_handler.hlabel_map_item.cclabel_map_item.h
proto
score_calibration.ccscore_calibration.htask_api_factory.htask_utils.cctask_utils.htflite_engine.cctflite_engine.hprocessor
BUILDaudio_preprocessor.ccaudio_preprocessor.hbert_preprocessor.ccbert_preprocessor.hclassification_postprocessor.ccclassification_postprocessor.hembedding_postprocessor.ccembedding_postprocessor.himage_preprocessor.ccimage_preprocessor.hprocessor.ccprocessor.h
proto
BUILDclass.protoclassification_options.protoclassifications.protoembedding.protoembedding_options.proto
regex_preprocessor.ccregex_preprocessor.htext_preprocessor.cctext_preprocessor.htext
BUILDbert_nl_classifier.ccbert_nl_classifier.hbert_question_answerer.ccbert_question_answerer.h
nlclassifier
BUILDbert_nl_classifier_c_api.ccbert_nl_classifier_c_api.hnl_classifier.ccnl_classifier.hnl_classifier_c_api.h
proto
BUILDbert_nl_classifier_options.protobert_nl_classifier_options_proto_inc.hbert_question_answerer_options.protobert_question_answerer_options_proto_inc.hnl_classifier_options.protonl_classifier_options_proto_inc.hretrieval.protoretrieval_proto_inc.h
qa
question_answerer.huniversal_sentence_encoder_qa.ccuniversal_sentence_encoder_qa.hvision
BUILD
core
BUILDbase_vision_task_api.hclassification_head.ccclassification_head.hframe_buffer.hlabel_map_item.cclabel_map_item.h
image_classifier.ccimage_classifier.himage_embedder.ccimage_embedder.himage_segmenter.ccimage_segmenter.hobject_detector.ccobject_detector.hproto
BUILDclass_proto_inc.hclassifications_proto_inc.hdetections_proto_inc.hembeddings.protoembeddings_proto_inc.himage_classifier_options.protoimage_classifier_options_proto_inc.himage_embedder_options.protoimage_embedder_options_proto_inc.himage_segmenter_options.protoimage_segmenter_options_proto_inc.hobject_detector_options.protoobject_detector_options_proto_inc.hsegmentations_proto_inc.h
utils
test
BUILDcommon_test.ccmessage_matchers.h
task
test_utils.cctest_utils.htestdata
task
text
BUILDalbert_with_metadata.jsonbert_nl_classifier.jsonempty_vocab_for_regex_tokenizer.txtmobilebert_with_metadata.jsontest_model_nl_classifier.tflitetest_model_nl_classifier_bool_output.tflitetest_model_nl_classifier_with_associated_label.jsontest_model_nl_classifier_with_associated_label.tflitetest_model_nl_classifier_with_associated_label_builtin_ops.jsontest_model_nl_classifier_with_associated_label_builtin_ops.tflitetest_model_nl_classifier_with_regex_tokenizer.jsontest_model_nl_classifier_with_regex_tokenizer.tflitevocab_for_regex_tokenizer.txt
vision
BUILDautoml_labeler_model.tfliteburger-224.pngburger.jpgburger_crop.jpgburger_rotation180.jpgcats_and_dogs.jpgcats_and_dogs_rotation180.jpgcoco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflitecoco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflitecoco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflitedeeplabv3.tflitedilated_conv.tflitemobilenet_v1_0.25_224_1_default_1.tflitemobilenet_v1_0.25_224_1_metadata_1.tflitemobilenet_v1_0.25_224_quant.tflitemobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflitemobilenet_v2_1.0_224.tflitemobilenet_v2_1.0_224_without_labels.jsonmobilenet_v2_1.0_224_without_labels.tflitemobilenet_v3_small_100_224_embedder.tflitemulti_objects.jpgsegmentation_golden_rotation0.pngsegmentation_golden_rotation0_yuv.pngsegmentation_golden_rotation90_flop.pngsegmentation_input_rotation0.jpgsegmentation_input_rotation90_flop.jpgsparrow.png
text
tokenizers
utils
codegen
custom_ops
examples
task
ios
BUILDTensorFlowLiteTaskText.podspec.templateallowlist_TensorFlowLiteTaskText.txtallowlist_TensorFlowLiteTaskVision.txt
sources
task
core
processor
BUILD
sources
TFLClassificationOptions+Helpers.hTFLClassificationOptions+Helpers.mTFLClassificationOptions.hTFLClassificationOptions.mTFLClassificationResult.hTFLClassificationResult.m
utils
text
vision
test
task
vision
image_classifier
text
tokenizers
java
BUILDREADME.mddefault_version_script.lds
src
java
org
tensorflow
lite
support
audio
common
image
BitmapContainer.javaBoundingBoxUtil.javaColorSpaceType.javaImageContainer.javaImageConversions.javaImageProcessor.javaImageProperties.javaMediaImageContainer.javaMlImageAdapter.javaTensorBufferContainer.javaTensorImage.java
ops
label
model
tensorbuffer
task
javatests
org
tensorflow
lite
support
AndroidManifest.xmlBUILD
assets
audio
common
image
BoundingBoxUtilTest.javaColorSpaceTypeInstrumentedTest.javaColorSpaceTypeTest.javaImageConversionsInstrumentedTest.javaImageConversionsTest.javaImageProcessorInstrumentedTest.javaImageProcessorTest.javaMlImageAdapterTest.javaTensorImageInstrumentedTest.javaTensorImageTest.javaTestImageCreator.java
ops
label
model
tensorbuffer
native
task
metadata
cc
flatbuffers_lib
java
metadata_schema.fbspython
BUILD__init__.pymetadata.pymetadata_displayer.pymetadata_writer_for_task.py
metadata_writers
BUILD__init__.pyaudio_classifier.pybert_nl_classifier.pyimage_classifier.pyimage_segmenter.pymetadata_info.pymetadata_writer.pynl_classifier.pyobject_detector.pywriter_utils.py
tests
BUILDmetadata_parser_test.pymetadata_test.pymetadata_writer_for_task_test.py
metadata_writers
BUILDaudio_classifier_test.pybert_nl_classifier_test.pyimage_classifier_test.pyimage_segmenter_test.pymetadata_info_test.pymetadata_writer_test.pynl_classifier_test.pyobject_detector_test.pytest_utils.pywriter_utils_test.py
testdata
BUILDassociated_file_meta.json
audio_classifier
BUILDdaredevil_sound_recognizer_320ms.jsonlabelmap.txttwo_heads.jsontwo_heads.tflitetwo_heads_default.jsonyamnet_521_labels.txtyamnet_tfhub.jsonyamnet_tfhub.tfliteyamnet_wavin_quantized_mel_relu6.jsonyamnet_wavin_quantized_mel_relu6.tfliteyamnet_wavin_quantized_mel_relu6_default.json
audio_embedder
bert_nl_classifier
BUILDbert_nl_classifier_default.jsonbert_nl_classifier_with_bert_tokenizer.jsonbert_nl_classifier_with_sentence_piece.jsonlabels.txt
bert_tokenizer_meta.jsonbounding_box_tensor_meta.jsoncategory_tensor_float_meta.jsonclassification_tensor_float_meta.jsonclassification_tensor_uint8_meta.jsonclassification_tensor_unsupported_meta.jsonfeature_tensor_meta.jsongeneral_meta.jsongolden_json.jsonimage_classifier
BUILDlabels.txtmobilenet_v2_1.0_224.jsonmobilenet_v2_1.0_224.tflitemobilenet_v2_1.0_224_default.jsonmobilenet_v2_1.0_224_quant.jsonmobilenet_v2_1.0_224_quant.tflitescore_calibration.txt
image_segmenter
image_tensor_meta.jsoninput_audio_tesnor_default_meta.jsoninput_audio_tesnor_meta.jsoninput_image_tensor_float_meta.jsoninput_image_tensor_uint8_meta.jsoninput_image_tensor_unsupported_meta.jsoninput_text_tesnor_default_meta.jsoninput_text_tesnor_meta.jsonlabels.txtmobilenet_v2_1.0_224_quant.jsonmobilenet_v2_1.0_224_quant.tflitemobilenet_v2_1.0_224_quant_default.jsonmobilenet_v2_1.0_224_quant_dummy.jsonmobilenet_v2_1.0_224_quant_dummy_no_version.jsonmobilenet_v2_1.0_224_quant_meta_info_.jsonmulti_inputs.jsonmulti_outputs.jsonnl_classifier
object_detector
BUILDcoco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflitecoco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflitecoco_ssd_mobilenet_v1_score_calibration.jsoncoco_ssd_mobilenet_v1_score_calibration_dummy.jsonefficientdet_lite0_v1.jsonefficientdet_lite0_v1.tfliteefficientdet_lite0_v1_default.jsonlabelmap.txtscore_calibration.csvscore_calibration_dummy.csvssd_mobilenet_v1.jsonssd_mobilenet_v1.tflitessd_mobilenet_v1_default.json
question_answerer
regex_tokenizer_meta.jsonscore_calibration_file_meta.jsonscore_calibration_tensor_meta.jsonsentence_piece_tokenizer_meta.jsonodml
README.md
ios
java
image
AndroidManifest.xmlBUILDimage.pgcfg
src
com
google
android
tests
src
third_party_licenses
opensource
tools
BUILDBuild_TFLite_Support_Targets.ipynb
build_rules
BUILD
android_test
AndroidManifest_instrumentation_test_template.xmlAndroidManifest_target_stub.xmlBUILDandroid_library_instrumentation_tests.bzlandroid_multidevice_instrumentation_test.bzlgenerate_instrumentation_tests.bzlinfer_java_package_name.bzl
http_files.bzlci_build
pip_package
third_party
@ -7,7 +7,7 @@
|
||||
#include "base/trace_event/trace_event.h"
|
||||
#include "components/optimization_guide/core/model_util.h"
|
||||
#include "components/optimization_guide/core/tflite_op_resolver.h"
|
||||
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
|
||||
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
|
||||
|
||||
namespace optimization_guide {
|
||||
|
||||
@ -28,18 +28,21 @@ BertModelExecutor::Execute(ModelExecutionTask* execution_task,
|
||||
GetStringNameForOptimizationTarget(optimization_target_),
|
||||
"input_length", input.size());
|
||||
*out_status = ExecutionStatus::kSuccess;
|
||||
return static_cast<tflite::task::text::nlclassifier::BertNLClassifier*>(
|
||||
execution_task)
|
||||
return static_cast<tflite::task::text::BertNLClassifier*>(execution_task)
|
||||
->Classify(input);
|
||||
}
|
||||
|
||||
std::unique_ptr<BertModelExecutor::ModelExecutionTask>
|
||||
BertModelExecutor::BuildModelExecutionTask(base::MemoryMappedFile* model_file,
|
||||
ExecutionStatus* out_status) {
|
||||
tflite::task::text::BertNLClassifierOptions options;
|
||||
*options.mutable_base_options()
|
||||
->mutable_model_file()
|
||||
->mutable_file_content() = std::string(
|
||||
reinterpret_cast<const char*>(model_file->data()), model_file->length());
|
||||
auto maybe_nl_classifier =
|
||||
tflite::task::text::nlclassifier::BertNLClassifier::CreateFromBuffer(
|
||||
reinterpret_cast<const char*>(model_file->data()),
|
||||
model_file->length(), std::make_unique<TFLiteOpResolver>());
|
||||
tflite::task::text::BertNLClassifier::CreateFromOptions(
|
||||
std::move(options), std::make_unique<TFLiteOpResolver>());
|
||||
if (maybe_nl_classifier.ok())
|
||||
return std::move(maybe_nl_classifier.value());
|
||||
*out_status = ExecutionStatus::kErrorModelFileNotValid;
|
||||
|
@ -64,7 +64,12 @@ absl::Status ModelValidatorExecutor::Preprocess(
|
||||
float ModelValidatorExecutor::Postprocess(
|
||||
const std::vector<const TfLiteTensor*>& output_tensors) {
|
||||
std::vector<float> data;
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
absl::Status status =
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
if (!status.ok()) {
|
||||
NOTREACHED();
|
||||
return -1;
|
||||
}
|
||||
return data[0];
|
||||
}
|
||||
|
||||
|
@ -11,14 +11,15 @@ namespace optimization_guide {
|
||||
absl::Status TestTFLiteModelExecutor::Preprocess(
|
||||
const std::vector<TfLiteTensor*>& input_tensors,
|
||||
const std::vector<float>& input) {
|
||||
tflite::task::core::PopulateTensor<float>(input, input_tensors[0]);
|
||||
return absl::OkStatus();
|
||||
return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]);
|
||||
}
|
||||
|
||||
std::vector<float> TestTFLiteModelExecutor::Postprocess(
|
||||
const std::vector<const TfLiteTensor*>& output_tensors) {
|
||||
std::vector<float> data;
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
absl::Status status =
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
DCHECK(status.ok());
|
||||
return data;
|
||||
}
|
||||
|
||||
|
@ -28,43 +28,90 @@ absl::Status PredictionModelExecutor::Preprocess(
|
||||
NOTREACHED();
|
||||
}
|
||||
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
absl::Status status = tflite::task::core::PopulateTensor<float>(
|
||||
input.client_features().client_stats().avg_deny_rate(), input_tensors[0]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.client_features().client_stats().avg_dismiss_rate(),
|
||||
input_tensors[1]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.client_features().client_stats().avg_grant_rate(),
|
||||
input_tensors[2]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.client_features().client_stats().avg_ignore_rate(),
|
||||
input_tensors[3]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.permission_features()[0].permission_stats().avg_deny_rate(),
|
||||
input_tensors[4]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.permission_features()[0].permission_stats().avg_dismiss_rate(),
|
||||
input_tensors[5]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.permission_features()[0].permission_stats().avg_grant_rate(),
|
||||
input_tensors[6]);
|
||||
tflite::task::core::PopulateTensor<float>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<float>(
|
||||
input.permission_features()[0].permission_stats().avg_ignore_rate(),
|
||||
input_tensors[7]);
|
||||
tflite::task::core::PopulateTensor<int64_t>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<int64_t>(
|
||||
static_cast<int64_t>(
|
||||
input.permission_features()[0].permission_stats().prompts_count()),
|
||||
input_tensors[8]);
|
||||
tflite::task::core::PopulateTensor<int64_t>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<int64_t>(
|
||||
static_cast<int64_t>(
|
||||
input.client_features().client_stats().prompts_count()),
|
||||
input_tensors[9]);
|
||||
tflite::task::core::PopulateTensor<int64_t>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<int64_t>(
|
||||
static_cast<int64_t>(input.client_features().gesture_enum()),
|
||||
input_tensors[10]);
|
||||
tflite::task::core::PopulateTensor<int64_t>(
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = tflite::task::core::PopulateTensor<int64_t>(
|
||||
static_cast<int64_t>(input.client_features().platform_enum()),
|
||||
input_tensors[11]);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
@ -73,7 +120,10 @@ GeneratePredictionsResponse PredictionModelExecutor::Postprocess(
|
||||
DCHECK(request_type_ == RequestType::kNotifications ||
|
||||
request_type_ == RequestType::kGeolocation);
|
||||
std::vector<float> data;
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
absl::Status status =
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
DCHECK(status.ok());
|
||||
|
||||
GeneratePredictionsResponse response;
|
||||
float threshold = request_type_ == RequestType::kNotifications
|
||||
? kNotificationPredictionsThreshold
|
||||
|
@ -31,8 +31,7 @@ absl::Status SegmentationModelExecutor::Preprocess(
|
||||
"length of input data does not match length of tensor");
|
||||
}
|
||||
|
||||
tflite::task::core::PopulateTensor<float>(input, input_tensors[0]);
|
||||
return absl::OkStatus();
|
||||
return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]);
|
||||
}
|
||||
|
||||
float SegmentationModelExecutor::Postprocess(
|
||||
@ -43,7 +42,12 @@ float SegmentationModelExecutor::Postprocess(
|
||||
DCHECK_EQ(1u, output_tensors[0]->bytes / sizeof(output_tensors[0]->type));
|
||||
|
||||
std::vector<float> data;
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
absl::Status status =
|
||||
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
|
||||
if (!status.ok()) {
|
||||
NOTREACHED();
|
||||
return -1;
|
||||
}
|
||||
DCHECK_EQ(1u, data.size());
|
||||
return data[0];
|
||||
}
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include "components/translate/core/language_detection/language_detection_model.h"
|
||||
|
||||
#include "base/files/memory_mapped_file.h"
|
||||
#include "base/cxx17_backports.h"
|
||||
#include "base/metrics/histogram_macros.h"
|
||||
#include "base/metrics/histogram_macros_local.h"
|
||||
#include "base/strings/utf_string_conversions.h"
|
||||
@ -73,19 +73,27 @@ void LanguageDetectionModel::UpdateWithFile(base::File model_file) {
|
||||
if (!model_file.IsValid())
|
||||
return;
|
||||
|
||||
if (!model_fb_.Initialize(std::move(model_file)))
|
||||
return;
|
||||
|
||||
recorder.set_state(
|
||||
LanguageDetectionModelState::kModelFileValidAndMemoryMapped);
|
||||
|
||||
auto statusor_classifier = tflite::task::text::nlclassifier::NLClassifier::
|
||||
CreateFromBufferAndOptions(
|
||||
reinterpret_cast<const char*>(model_fb_.data()), model_fb_.length(),
|
||||
{.input_tensor_index = 0,
|
||||
.output_score_tensor_index = 0,
|
||||
.output_label_tensor_index = 2},
|
||||
CreateLangIdResolver());
|
||||
tflite::task::text::NLClassifierOptions options;
|
||||
options.set_input_tensor_index(0);
|
||||
options.set_output_score_tensor_index(0);
|
||||
options.set_output_label_tensor_index(2);
|
||||
|
||||
std::string file_content(model_file.GetLength(), '\0');
|
||||
int bytes_read =
|
||||
model_file.Read(0, base::data(file_content), model_file.GetLength());
|
||||
if (bytes_read != model_file.GetLength()) {
|
||||
return;
|
||||
}
|
||||
*options.mutable_base_options()
|
||||
->mutable_model_file()
|
||||
->mutable_file_content() = std::move(file_content);
|
||||
|
||||
auto statusor_classifier =
|
||||
tflite::task::text::nlclassifier::NLClassifier::CreateFromOptions(
|
||||
options, CreateLangIdResolver());
|
||||
if (!statusor_classifier.ok()) {
|
||||
LOCAL_HISTOGRAM_BOOLEAN("LanguageDetection.TFLiteModel.InvalidModelFile",
|
||||
true);
|
||||
|
@ -6,7 +6,8 @@
|
||||
#define COMPONENTS_TRANSLATE_CORE_LANGUAGE_DETECTION_LANGUAGE_DETECTION_MODEL_H_
|
||||
|
||||
#include <string>
|
||||
#include "base/files/memory_mapped_file.h"
|
||||
|
||||
#include "base/files/file.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
@ -70,11 +71,6 @@ class LanguageDetectionModel {
|
||||
std::pair<std::string, float> DetectTopLanguage(
|
||||
const std::string& sampled_str) const;
|
||||
|
||||
// A memory-mapped file that contains the TFLite model used for
|
||||
// determining the language of a page. This must be valid in order
|
||||
// to evaluate the model owned by |this|.
|
||||
base::MemoryMappedFile model_fb_;
|
||||
|
||||
// The tflite classifier that can determine the language of text.
|
||||
std::unique_ptr<tflite::task::text::nlclassifier::NLClassifier>
|
||||
lang_detection_model_;
|
||||
|
@ -734,15 +734,30 @@ third_party/tensorflow-text/src/tensorflow_text/python/benchmarks/test_data/unca
|
||||
third_party/tensorflow-text/src/tensorflow_text/python/metrics 1 1
|
||||
third_party/tensorflow-text/src/tensorflow_text/python/ops/test_data 1 1
|
||||
third_party/test_fonts 15 1
|
||||
third_party/tflite_support/patches 8 2
|
||||
third_party/tflite_support/src 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/acceleration 3 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/c 2 2
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task 4 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision 9 4
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto 4 2
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/proto 2 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision 13 5
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/proto 12 4
|
||||
third_party/tflite_support/src/tensorflow_lite_support/cc/test/testdata/task/text 5 3
|
||||
third_party/tflite_support/src/tensorflow_lite_support/codegen 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece 2 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python 2 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop 4 2
|
||||
third_party/tflite_support/src/tensorflow_lite_support/custom_ops/testdata 5 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop 8 3
|
||||
third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python 12 3
|
||||
third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier 1 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/metadata 5 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/metadata/java/src/java/org/tensorflow/lite/support/metadata 3 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/testdata/nl_classifier 2 1
|
||||
third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image 1 1
|
||||
third_party/tlslite 1 1
|
||||
third_party/tlslite/patches 7 3
|
||||
third_party/tlslite/tlslite 3 2
|
||||
|
10
third_party/tflite/BUILD.gn
vendored
10
third_party/tflite/BUILD.gn
vendored
@ -414,8 +414,18 @@ static_library("tflite") {
|
||||
"src/tensorflow/lite/core/api/tensor_utils.cc",
|
||||
"src/tensorflow/lite/core/api/tensor_utils.h",
|
||||
"src/tensorflow/lite/core/subgraph.cc",
|
||||
"src/tensorflow/lite/delegates/interpreter_utils.cc",
|
||||
"src/tensorflow/lite/delegates/interpreter_utils.h",
|
||||
"src/tensorflow/lite/delegates/nnapi/nnapi_delegate.h",
|
||||
"src/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/delegate_registry.cc",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.cc",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.h",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.cc",
|
||||
"src/tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.h",
|
||||
"src/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.cc",
|
||||
"src/tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h",
|
||||
"src/tensorflow/lite/experimental/resource/initialization_status.cc",
|
||||
"src/tensorflow/lite/experimental/resource/initialization_status.h",
|
||||
"src/tensorflow/lite/experimental/resource/lookup_interfaces.h",
|
||||
|
27
third_party/tflite_support/BUILD.gn
vendored
27
third_party/tflite_support/BUILD.gn
vendored
@ -16,13 +16,19 @@ config("tflite_support_config") {
|
||||
proto_library("tflite_support_proto") {
|
||||
proto_in_dir = "src"
|
||||
sources = [
|
||||
"src/tensorflow_lite_support/cc/task/core/proto/base_options.proto",
|
||||
"src/tensorflow_lite_support/cc/task/core/proto/external_file.proto",
|
||||
"src/tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto",
|
||||
"src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options.proto",
|
||||
"src/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto",
|
||||
"src/tensorflow_lite_support/cc/task/vision/proto/class.proto",
|
||||
"src/tensorflow_lite_support/cc/task/vision/proto/classifications.proto",
|
||||
"src/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto",
|
||||
]
|
||||
cc_generator_options = "lite=true:"
|
||||
|
||||
import_dirs = [ "//third_party/tflite/src" ]
|
||||
proto_deps = [ "//third_party/tflite:tflite-config-proto" ]
|
||||
}
|
||||
|
||||
config("tflite_support_flags") {
|
||||
@ -31,6 +37,9 @@ config("tflite_support_flags") {
|
||||
"-Wno-extern-c-compat",
|
||||
"-Wno-implicit-function-declaration",
|
||||
"-Wno-sign-compare",
|
||||
"-Wno-ignored-attributes",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-Wno-unused-variable",
|
||||
]
|
||||
if (!is_win) {
|
||||
cflags_cc = [ "-frtti" ]
|
||||
@ -51,15 +60,14 @@ static_library("tflite_support") {
|
||||
sources = [
|
||||
"src/tensorflow_lite_support/cc/common.cc",
|
||||
"src/tensorflow_lite_support/cc/common.h",
|
||||
"src/tensorflow_lite_support/cc/port/default/statusor.cc",
|
||||
"src/tensorflow_lite_support/cc/port/default/statusor.h",
|
||||
"src/tensorflow_lite_support/cc/port/default/statusor_internals.h",
|
||||
"src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc",
|
||||
"src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h",
|
||||
"src/tensorflow_lite_support/cc/port/status_macros.h",
|
||||
"src/tensorflow_lite_support/cc/port/statusor.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/base_task_api.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/category.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/error_reporter.cc",
|
||||
"src/tensorflow_lite_support/cc/task/core/error_reporter.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/external_file_handler.cc",
|
||||
"src/tensorflow_lite_support/cc/task/core/external_file_handler.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h",
|
||||
@ -68,10 +76,15 @@ static_library("tflite_support") {
|
||||
"src/tensorflow_lite_support/cc/task/core/task_utils.h",
|
||||
"src/tensorflow_lite_support/cc/task/core/tflite_engine.cc",
|
||||
"src/tensorflow_lite_support/cc/task/core/tflite_engine.h",
|
||||
"src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc",
|
||||
"src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h",
|
||||
"src/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc",
|
||||
"src/tensorflow_lite_support/cc/task/processor/image_preprocessor.h",
|
||||
"src/tensorflow_lite_support/cc/task/processor/processor.cc",
|
||||
"src/tensorflow_lite_support/cc/task/processor/processor.h",
|
||||
"src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc",
|
||||
"src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h",
|
||||
"src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc",
|
||||
"src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h",
|
||||
"src/tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h",
|
||||
"src/tensorflow_lite_support/cc/task/vision/core/classification_head.cc",
|
||||
"src/tensorflow_lite_support/cc/task/vision/core/classification_head.h",
|
||||
"src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc",
|
||||
@ -126,8 +139,10 @@ static_library("tflite_support") {
|
||||
configs -= [ "//build/config/compiler:chromium_code" ]
|
||||
|
||||
configs += [
|
||||
":tflite_support_flags",
|
||||
"//build/config/compiler:no_chromium_code",
|
||||
|
||||
# Must be after no_chromium_code for warning flags to be ordered correctly.
|
||||
":tflite_support_flags",
|
||||
]
|
||||
|
||||
public_configs = [ ":tflite_support_config" ]
|
||||
|
37
third_party/tflite_support/README.chromium
vendored
37
third_party/tflite_support/README.chromium
vendored
@ -1,8 +1,8 @@
|
||||
Name: TensorFlow Lite Support
|
||||
Short Name: tflite-support
|
||||
URL: https://github.com/tensorflow/tflite-support
|
||||
Version: 3faaca9c6a3b22dec4d636b6b092431c9ac409e8
|
||||
Date: 2021/01/05
|
||||
Version: v0.3.1
|
||||
Date: 2021/12/16
|
||||
License: Apache 2.0
|
||||
License File: LICENSE
|
||||
Security Critical: Yes
|
||||
@ -13,27 +13,30 @@ TFLite Support is a toolkit that helps users to develop ML and deploy TFLite
|
||||
models onto mobile devices. It works cross-Platform and is supported on
|
||||
Java, C++ (WIP), and Swift (WIP).
|
||||
|
||||
Modifications:
|
||||
- Use chromium's logging utility in place of glog (patches/0001-use-base-logging.patch)
|
||||
- Use size_t rather than int for loops according to chromium style (0001-use-size_t.patch)
|
||||
- Rely on re2::StringPiece instead of absl::string_view (0001-use-StringPiece-for-string_view.patch)
|
||||
- Remove unsafe use of conversions between NSString and string by using SysNSStringToUTF8. Note, this
|
||||
is unused code but required for presubmit checks. (0001-use-SysNSStringToUTF8.patch)
|
||||
- Remove usage of absl::Cord in tflite::support::CreateStatusWithPayload (0001-no-absl-cord.patch)
|
||||
- Use _Exit instead of _exit to work on all platforms (0001-use-exit.patch)
|
||||
- Remove external file handlers support for memory mapping files to support Windows
|
||||
(0001-remove-unsupported-memory-map-from-file-handler.patch)
|
||||
- Fixes sign compare issues in tflite-support (0001-task-utils-sign-compare.patch)
|
||||
- Remove support for sentencepiece tokenizers (0001-no-sentencepiece-tokenizer.patch)
|
||||
- Allows for the max sequence used by BERT models to be 512 instead of 128 (0001-bert-max-seq-len.patch)
|
||||
- Ensure name field in metadata exists before checking for tflite metadata (0001-add-metadata-name-check.patch)
|
||||
|
||||
Third party dependencies:
|
||||
- tflite
|
||||
- libzip
|
||||
- utf
|
||||
- tensorflow-text
|
||||
|
||||
Modifications:
|
||||
01) Use re2::StringPiece instead of absl::string_view in regex_tokenizer.cc
|
||||
02) Remove support for sentencepiece tokenization because the required overhead
|
||||
isn't worth adding this functionality, especially since no feature team needs it.
|
||||
03) [To Be Upstreamed] Remove unused functions.
|
||||
04) Remove the ABSL_DEPRECATED annotation from a deprecated struct since this
|
||||
is a no-op in chromium builds and upsets clang.
|
||||
05) [To Be Upstreamed] Use size_t in for loop in nl_classifier.h
|
||||
06) [To Be Upstreamed] Remove unused variable in task_utils.h
|
||||
07) Do not use absl::any since it is not supported in chromium
|
||||
08) [To Be Upstreamed] Remove unused stl include in tokenizer_jni_lib.h
|
||||
09) Remove unbuilt files that triggered checkdeps warnings, and fix file perms.
|
||||
10) Remove memory mapped file support in external_file_handler.cc since it is
|
||||
only available on POSIX systems.
|
||||
|
||||
Update Process:
|
||||
1) Clone the tflite-support github repo at the desired commit into src/
|
||||
2) Apply each patch listed above residing in patches/ using `git apply patches/$PATCHFILE`
|
||||
3) Get the build working.
|
||||
4) Record the patches made with `git format-patch HEAD -<number of changes>`
|
||||
|
||||
|
34
third_party/tflite_support/patches/0001-Fix-signed-comparison-in-base_vision_task_api.h.patch
vendored
34
third_party/tflite_support/patches/0001-Fix-signed-comparison-in-base_vision_task_api.h.patch
vendored
@ -1,34 +0,0 @@
|
||||
From c8bdfe3f6b3ce087c36b551d668b97101f620bdc Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Rubery <drubery@chromium.org>
|
||||
Date: Thu, 6 May 2021 11:45:48 -0700
|
||||
Subject: [PATCH] Fix signed comparison in base_vision_task_api.h
|
||||
|
||||
---
|
||||
.../cc/task/vision/core/base_vision_task_api.h | 4 ++--
|
||||
1 file changed, 2 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
|
||||
index 3d1359685f3f..c787876bec33 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
|
||||
@@ -204,7 +204,7 @@ class BaseVisionTaskApi
|
||||
if (normalization_options.num_values == 1) {
|
||||
float mean_value = normalization_options.mean_values[0];
|
||||
float inv_std_value = (1.0f / normalization_options.std_values[0]);
|
||||
- for (int i = 0; i < input_data_byte_size / sizeof(uint8);
|
||||
+ for (size_t i = 0; i < input_data_byte_size / sizeof(uint8);
|
||||
i++, input_data++, normalized_input_data++) {
|
||||
*normalized_input_data =
|
||||
inv_std_value * (static_cast<float>(*input_data) - mean_value);
|
||||
@@ -214,7 +214,7 @@ class BaseVisionTaskApi
|
||||
1.0f / normalization_options.std_values[0],
|
||||
1.0f / normalization_options.std_values[1],
|
||||
1.0f / normalization_options.std_values[2]};
|
||||
- for (int i = 0; i < input_data_byte_size / sizeof(uint8);
|
||||
+ for (size_t i = 0; i < input_data_byte_size / sizeof(uint8);
|
||||
i++, input_data++, normalized_input_data++) {
|
||||
*normalized_input_data = inv_std_values[i % 3] *
|
||||
(static_cast<float>(*input_data) -
|
||||
--
|
||||
2.31.1.607.g51e8a6a459-goog
|
||||
|
25
third_party/tflite_support/patches/0001-Remove-signed-comparison-in-frame_buffer.h.patch
vendored
25
third_party/tflite_support/patches/0001-Remove-signed-comparison-in-frame_buffer.h.patch
vendored
@ -1,25 +0,0 @@
|
||||
From 368b317061ba7deb1f42c52c5443c261bb6c03ea Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Rubery <drubery@chromium.org>
|
||||
Date: Thu, 6 May 2021 11:40:37 -0700
|
||||
Subject: [PATCH] Remove signed comparison in frame_buffer.h
|
||||
|
||||
---
|
||||
.../tensorflow_lite_support/cc/task/vision/core/frame_buffer.h | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
index 22f63fc34d36..42ac080c4749 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
@@ -246,7 +246,7 @@ class FrameBuffer {
|
||||
|
||||
// Returns plane indexed by the input `index`.
|
||||
const Plane plane(int index) const {
|
||||
- if (index > -1 && index < planes_.size()) {
|
||||
+ if (index > -1 && static_cast<size_t>(index) < planes_.size()) {
|
||||
return planes_[index];
|
||||
}
|
||||
return {};
|
||||
--
|
||||
2.31.1.607.g51e8a6a459-goog
|
||||
|
38
third_party/tflite_support/patches/0001-Remove-unused-qualifiers-in-frame_buffer.h.patch
vendored
38
third_party/tflite_support/patches/0001-Remove-unused-qualifiers-in-frame_buffer.h.patch
vendored
@ -1,38 +0,0 @@
|
||||
From b23fcde4753dbf5e4adc325e9ded16800f1d1bc5 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Rubery <drubery@chromium.org>
|
||||
Date: Thu, 6 May 2021 11:38:06 -0700
|
||||
Subject: [PATCH] Remove unused qualifiers in frame_buffer.h
|
||||
|
||||
---
|
||||
.../cc/task/vision/core/frame_buffer.h | 6 +++---
|
||||
1 file changed, 3 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
index 1556b7dfabef..22f63fc34d36 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
@@ -242,7 +242,7 @@ class FrameBuffer {
|
||||
timestamp_(timestamp) {}
|
||||
|
||||
// Returns number of planes.
|
||||
- const int plane_count() const { return planes_.size(); }
|
||||
+ int plane_count() const { return planes_.size(); }
|
||||
|
||||
// Returns plane indexed by the input `index`.
|
||||
const Plane plane(int index) const {
|
||||
@@ -256,10 +256,10 @@ class FrameBuffer {
|
||||
const Dimension dimension() const { return dimension_; }
|
||||
|
||||
// Returns FrameBuffer format.
|
||||
- const Format format() const { return format_; }
|
||||
+ Format format() const { return format_; }
|
||||
|
||||
// Returns FrameBuffer orientation.
|
||||
- const Orientation orientation() const { return orientation_; }
|
||||
+ Orientation orientation() const { return orientation_; }
|
||||
|
||||
// Returns FrameBuffer timestamp.
|
||||
const absl::Time timestamp() const { return timestamp_; }
|
||||
--
|
||||
2.31.1.607.g51e8a6a459-goog
|
||||
|
@ -1,30 +0,0 @@
|
||||
From 226d36a5d12ca3080b1d0d9b450be949e418e318 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Rubery <drubery@chromium.org>
|
||||
Date: Thu, 6 May 2021 11:26:23 -0700
|
||||
Subject: [PATCH] Use third_party/libyuv
|
||||
|
||||
---
|
||||
.../cc/task/vision/utils/libyuv_frame_buffer_utils.cc | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
index f3cd0c70fe1b..b50b500bb5a4 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
@@ -23,12 +23,12 @@ limitations under the License.
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
-#include "include/libyuv.h"
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
#include "tensorflow_lite_support/cc/port/integral_types.h"
|
||||
#include "tensorflow_lite_support/cc/port/status_macros.h"
|
||||
#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
|
||||
+#include "third_party/libyuv/include/libyuv.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
--
|
||||
2.31.1.607.g51e8a6a459-goog
|
||||
|
@ -1,28 +0,0 @@
|
||||
From 5e3e4b63a6bfd871afa16f8d27f2daa8b99d84e9 Mon Sep 17 00:00:00 2001
|
||||
From: mcrouse <mcrouse@google.com>
|
||||
Date: Thu, 19 Aug 2021 11:31:28 -0700
|
||||
Subject: [PATCH] add metadata name check
|
||||
|
||||
---
|
||||
.../metadata/cc/metadata_extractor.cc | 5 ++++-
|
||||
1 file changed, 4 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
|
||||
index ad5df76f1c27b..42f2a7c13a516 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
|
||||
@@ -159,7 +159,10 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
|
||||
// Look for the "TFLITE_METADATA" field, if any.
|
||||
for (int i = 0; i < model_->metadata()->size(); ++i) {
|
||||
const auto metadata = model_->metadata()->Get(i);
|
||||
- if (metadata->name() && metadata->name()->str() != kMetadataBufferName) {
|
||||
+ if (!metadata->name()) {
|
||||
+ continue;
|
||||
+ }
|
||||
+ if (metadata->name()->str() != kMetadataBufferName) {
|
||||
continue;
|
||||
}
|
||||
const auto buffer_index = metadata->buffer();
|
||||
--
|
||||
2.33.0.rc2.250.ged5fa647cd-goog
|
||||
|
@ -1,25 +0,0 @@
|
||||
From 49cd597b3c1fbfef2e3772682aa98575654131ba Mon Sep 17 00:00:00 2001
|
||||
From: Sophie Chang <sophiechang@chromium.org>
|
||||
Date: Mon, 1 Mar 2021 19:33:21 +0000
|
||||
Subject: [PATCH] allow for more tokens
|
||||
|
||||
---
|
||||
.../cc/task/text/nlclassifier/bert_nl_classifier.h | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
|
||||
index cd5c5a3ade03..e78085d98761 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
|
||||
@@ -52,7 +52,7 @@ class BertNLClassifier : public NLClassifier {
|
||||
public:
|
||||
using NLClassifier::NLClassifier;
|
||||
// Max number of tokens to pass to the model.
|
||||
- static constexpr int kMaxSeqLen = 128;
|
||||
+ static constexpr int kMaxSeqLen = 512;
|
||||
|
||||
// Factory function to create a BertNLClassifier from TFLite model with
|
||||
// metadata.
|
||||
--
|
||||
2.30.1.766.gb4fecdf3b7-goog
|
||||
|
@ -1,33 +0,0 @@
|
||||
From 61fb20a08d2325d03759a5b9394c033901fc0a7f Mon Sep 17 00:00:00 2001
|
||||
From: Sophie Chang <sophiechang@chromium.org>
|
||||
Date: Wed, 3 Feb 2021 04:21:19 +0000
|
||||
Subject: [PATCH] do not use cord in tflite status payload
|
||||
|
||||
---
|
||||
.../tflite-support/src/tensorflow_lite_support/cc/common.cc | 3 ---
|
||||
1 file changed, 3 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc
|
||||
index 47dd3bcc6581..ed373e96d555 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/common.cc
|
||||
@@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
|
||||
-#include "absl/strings/cord.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
||||
namespace tflite {
|
||||
@@ -26,8 +25,6 @@ absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code,
|
||||
TfLiteSupportStatus tfls_code) {
|
||||
// NOTE: Ignores `message` if the canonical code is ok.
|
||||
absl::Status status = absl::Status(canonical_code, message);
|
||||
- // NOTE: Does nothing if the canonical code is ok.
|
||||
- status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code)));
|
||||
return status;
|
||||
}
|
||||
|
||||
--
|
||||
2.30.0.365.g02bc693789-goog
|
||||
|
@ -1,34 +0,0 @@
|
||||
From f84b50f175efff54ee6a6ef795703907245260cd Mon Sep 17 00:00:00 2001
|
||||
From: Sophie Chang <sophiechang@chromium.org>
|
||||
Date: Wed, 10 Feb 2021 17:55:30 +0000
|
||||
Subject: [PATCH] fix sign issues
|
||||
|
||||
---
|
||||
.../src/tensorflow_lite_support/cc/task/core/task_utils.h | 4 ++--
|
||||
1 file changed, 2 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
index 744dbbfb0f80..ced3dbcae9e4 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
@@ -119,7 +119,7 @@ inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) {
|
||||
const T* results = GetTensorData<T>(tensor);
|
||||
size_t num = tensor->bytes / sizeof(tensor->type);
|
||||
data->reserve(num);
|
||||
- for (int i = 0; i < num; i++) {
|
||||
+ for (size_t i = 0; i < num; i++) {
|
||||
data->emplace_back(results[i]);
|
||||
}
|
||||
}
|
||||
@@ -169,7 +169,7 @@ static TensorType* FindTensorByName(
|
||||
tensor_metadatas->size() != tensors.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
- for (int i = 0; i < tensor_metadatas->size(); i++) {
|
||||
+ for (flatbuffers::uoffset_t i = 0; i < tensor_metadatas->size(); i++) {
|
||||
if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) {
|
||||
return tensors[i];
|
||||
}
|
||||
--
|
||||
2.30.0.478.g8a0d178c01-goog
|
||||
|
@ -1,38 +0,0 @@
|
||||
From 81287a62d65139f29c512fed88ed734bef2c33f5 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Crouse <mcrouse@chromium.org>
|
||||
Date: Tue, 22 Dec 2020 14:25:39 -0800
|
||||
Subject: [PATCH] use StringPiece for string_view
|
||||
|
||||
---
|
||||
.../cc/text/tokenizers/regex_tokenizer.cc | 10 +++++-----
|
||||
1 file changed, 5 insertions(+), 5 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
index 38aff8805b30..44c43b2d5086 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
@@ -61,16 +61,16 @@ RegexTokenizer::RegexTokenizer(const std::string& regex_pattern,
|
||||
}
|
||||
|
||||
TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
|
||||
- absl::string_view leftover(input.data());
|
||||
- absl::string_view last_end = leftover;
|
||||
+ re2::StringPiece leftover(input.data());
|
||||
+ re2::StringPiece last_end = leftover;
|
||||
|
||||
TokenizerResult result;
|
||||
|
||||
// Keep looking for split points until we have reached the end of the input.
|
||||
- absl::string_view extracted_delim_token;
|
||||
+ re2::StringPiece extracted_delim_token;
|
||||
while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
|
||||
- absl::string_view token(last_end.data(),
|
||||
- extracted_delim_token.data() - last_end.data());
|
||||
+ re2::StringPiece token(last_end.data(),
|
||||
+ extracted_delim_token.data() - last_end.data());
|
||||
bool has_non_empty_token = token.length() > 0;
|
||||
|
||||
last_end = leftover;
|
||||
--
|
||||
2.29.2.729.g45daf8777d-goog
|
||||
|
@ -1,29 +0,0 @@
|
||||
From e4b8790a56487279b084fb59a2186a8bfd24b838 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Crouse <mcrouse@chromium.org>
|
||||
Date: Thu, 7 Jan 2021 08:20:06 -0800
|
||||
Subject: [PATCH] use SysNSStringToUTF8
|
||||
|
||||
---
|
||||
.../tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm | 3 ++-
|
||||
1 file changed, 2 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
index 2a11bb673047..b82be34a9ab9 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
+#inmport "base/strings/sys_string_conversions.h"
|
||||
#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
|
||||
|
||||
std::string MakeString(NSString* str) {
|
||||
- return std::string([str UTF8String]);
|
||||
+ return SysNSStringToUTF8(str);
|
||||
}
|
||||
|
||||
NSString* MakeNSString(const std::string& str) {
|
||||
--
|
||||
2.29.2.729.g45daf8777d-goog
|
||||
|
@ -1,26 +0,0 @@
|
||||
From 5307d81798215dae084b5079e797fd4408040340 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Crouse <mcrouse@chromium.org>
|
||||
Date: Tue, 22 Dec 2020 14:18:09 -0800
|
||||
Subject: [PATCH] use base logging
|
||||
|
||||
---
|
||||
.../src/tensorflow_lite_support/cc/port/default/statusor.cc | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
index 547a79192324..182c37e4aaf6 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
@@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
-#include <glog/logging.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
+#include "base/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
--
|
||||
2.29.2.729.g45daf8777d-goog
|
||||
|
@ -1,25 +0,0 @@
|
||||
From 49de34d7489ba5218a822461a42786844a1e344b Mon Sep 17 00:00:00 2001
|
||||
From: Sophie Chang <sophiechang@chromium.org>
|
||||
Date: Wed, 3 Feb 2021 04:30:56 +0000
|
||||
Subject: [PATCH] use _Exit
|
||||
|
||||
---
|
||||
.../src/tensorflow_lite_support/cc/port/default/statusor.cc | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
index 182c37e4aaf6..058c0070f0da 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/port/default/statusor.cc
|
||||
@@ -50,7 +50,7 @@ void Helper::HandleInvalidStatusCtorArg(absl::Status* status) {
|
||||
void Helper::Crash(const absl::Status& status) {
|
||||
LOG(FATAL) << "Attempting to fetch value instead of handling error "
|
||||
<< status;
|
||||
- _exit(1);
|
||||
+ _Exit(1);
|
||||
}
|
||||
|
||||
void ThrowBadStatusOrAccess(absl::Status status) {
|
||||
--
|
||||
2.30.0.365.g02bc693789-goog
|
||||
|
36
third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch
vendored
Normal file
36
third_party/tflite_support/patches/0001-use-re2-StringPiece-for-RegexTokenizer-Tokenize.patch
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
From b16b7af6f58ede0718fabf9c0da7495c79400c90 Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Wed, 15 Dec 2021 14:39:48 -0800
|
||||
Subject: [PATCH 01/11] use re2 StringPiece for RegexTokenizer::Tokenize
|
||||
|
||||
---
|
||||
.../cc/text/tokenizers/regex_tokenizer.cc | 8 ++++----
|
||||
1 file changed, 4 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
index 564f5f63a0584..832f9df42f824 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
|
||||
@@ -61,15 +61,15 @@ RegexTokenizer::RegexTokenizer(const std::string& regex_pattern,
|
||||
}
|
||||
|
||||
TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
|
||||
- absl::string_view leftover(input.data());
|
||||
- absl::string_view last_end = leftover;
|
||||
+ re2::StringPiece leftover(input.data());
|
||||
+ re2::StringPiece last_end = leftover;
|
||||
|
||||
TokenizerResult result;
|
||||
|
||||
// Keep looking for split points until we have reached the end of the input.
|
||||
- absl::string_view extracted_delim_token;
|
||||
+ re2::StringPiece extracted_delim_token;
|
||||
while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
|
||||
- absl::string_view token(last_end.data(),
|
||||
+ re2::StringPiece token(last_end.data(),
|
||||
extracted_delim_token.data() - last_end.data());
|
||||
bool has_non_empty_token = token.length() > 0;
|
||||
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
@ -1,25 +0,0 @@
|
||||
From ecb535154168358a72de6b51099a9549b970bce5 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Crouse <mcrouse@chromium.org>
|
||||
Date: Tue, 22 Dec 2020 14:34:12 -0800
|
||||
Subject: [PATCH] use size_t
|
||||
|
||||
---
|
||||
.../cc/task/text/nlclassifier/nl_classifier.h | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
index d3c3bed2083d..4055a93467d4 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
@@ -151,7 +151,7 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
|
||||
const std::string& name,
|
||||
int index) {
|
||||
if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
|
||||
- for (int i = 0; i < metadata_array->size(); i++) {
|
||||
+ for (size_t i = 0; i < metadata_array->size(); i++) {
|
||||
if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
|
||||
return tensors[i];
|
||||
}
|
||||
--
|
||||
2.29.2.729.g45daf8777d-goog
|
||||
|
@ -1,16 +1,16 @@
|
||||
From 7faac3ddcbc05275d797dda64a9b9d7f2279ae1c Mon Sep 17 00:00:00 2001
|
||||
From: Sophie Chang <sophiechang@chromium.org>
|
||||
Date: Thu, 11 Feb 2021 00:53:47 +0000
|
||||
Subject: [PATCH] no sentencepiece tokenizer
|
||||
From 9aa45ea43f8d84db1e20674c294f9ab958b12d7e Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Wed, 15 Dec 2021 14:57:16 -0800
|
||||
Subject: [PATCH 02/11] sentencepiece tokenization not supported
|
||||
|
||||
---
|
||||
.../cc/text/tokenizers/tokenizer_utils.cc | 11 -----------
|
||||
1 file changed, 11 deletions(-)
|
||||
.../cc/text/tokenizers/tokenizer_utils.cc | 14 ++++----------
|
||||
1 file changed, 4 insertions(+), 10 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
index 352c4a8c5e4f..46786fd7faf8 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
index 9abca9691f058..28f0137f54278 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
|
||||
@@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include "tensorflow_lite_support/cc/port/status_macros.h"
|
||||
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
|
||||
@ -19,11 +19,18 @@ index 352c4a8c5e4f..46786fd7faf8 100644
|
||||
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
@@ -73,16 +72,6 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
|
||||
return absl::make_unique<BertTokenizer>(vocab_buffer.data(),
|
||||
@@ -29,7 +28,6 @@ namespace text {
|
||||
namespace tokenizer {
|
||||
|
||||
using ::tflite::ProcessUnit;
|
||||
-using ::tflite::SentencePieceTokenizerOptions;
|
||||
using ::tflite::support::CreateStatusWithPayload;
|
||||
using ::tflite::support::StatusOr;
|
||||
using ::tflite::support::TfLiteSupportStatus;
|
||||
@@ -74,14 +72,10 @@ StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
|
||||
vocab_buffer.size());
|
||||
}
|
||||
- case ProcessUnitOptions_SentencePieceTokenizerOptions: {
|
||||
case ProcessUnitOptions_SentencePieceTokenizerOptions: {
|
||||
- const tflite::SentencePieceTokenizerOptions* options =
|
||||
- tokenizer_process_unit->options_as<SentencePieceTokenizerOptions>();
|
||||
- ASSIGN_OR_RETURN(absl::string_view model_buffer,
|
||||
@ -32,9 +39,13 @@ index 352c4a8c5e4f..46786fd7faf8 100644
|
||||
- // TODO(b/160647204): Extract sentence piece model vocabulary
|
||||
- return absl::make_unique<SentencePieceTokenizer>(model_buffer.data(),
|
||||
- model_buffer.size());
|
||||
- }
|
||||
+ return CreateStatusWithPayload(
|
||||
+ absl::StatusCode::kInvalidArgument,
|
||||
+ "Chromium does not support sentencepiece tokenization",
|
||||
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
|
||||
}
|
||||
case ProcessUnitOptions_RegexTokenizerOptions: {
|
||||
const tflite::RegexTokenizerOptions* options =
|
||||
tokenizer_process_unit->options_as<RegexTokenizerOptions>();
|
||||
--
|
||||
2.30.0.478.g8a0d178c01-goog
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
224
third_party/tflite_support/patches/0003-rm-unused-func.patch
vendored
Normal file
224
third_party/tflite_support/patches/0003-rm-unused-func.patch
vendored
Normal file
@ -0,0 +1,224 @@
|
||||
From 7bd2e5f0e2bb560e55efc3dd86249ff42a10d08c Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Wed, 15 Dec 2021 15:08:59 -0800
|
||||
Subject: [PATCH 03/11] rm unused func
|
||||
|
||||
---
|
||||
.../vision/utils/libyuv_frame_buffer_utils.cc | 201 +-----------------
|
||||
1 file changed, 1 insertion(+), 200 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
index 0ece48636504e..6fd3ca81c984c 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
|
||||
@@ -1326,206 +1326,7 @@ absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
|
||||
}
|
||||
}
|
||||
|
||||
-// Returns the scaled dimension of the input_size best fit within the
|
||||
-// output_size bound while respecting the aspect ratio.
|
||||
-FrameBuffer::Dimension GetScaledDimension(FrameBuffer::Dimension input_size,
|
||||
- FrameBuffer::Dimension output_size) {
|
||||
- int original_width = input_size.width;
|
||||
- int original_height = input_size.height;
|
||||
- int bound_width = output_size.width;
|
||||
- int bound_height = output_size.height;
|
||||
- int new_width = original_width;
|
||||
- int new_height = original_height;
|
||||
-
|
||||
- // Try to fit the width first.
|
||||
- new_width = bound_width;
|
||||
- new_height = (new_width * original_height) / original_width;
|
||||
-
|
||||
- // Try to fit the height if needed.
|
||||
- if (new_height > bound_height) {
|
||||
- new_height = bound_height;
|
||||
- new_width = (new_height * original_width) / original_height;
|
||||
- }
|
||||
- return FrameBuffer::Dimension{.width = new_width, .height = new_height};
|
||||
-}
|
||||
-
|
||||
-// This method only supports kGRAY, kRGBA, and kRGB formats.
|
||||
-absl::Status UniformCropResizePlane(const FrameBuffer& buffer,
|
||||
- std::vector<int> crop_coordinates,
|
||||
- FrameBuffer* output_buffer) {
|
||||
- int x0 = 0, y0 = 0;
|
||||
- FrameBuffer::Dimension input_dimension = buffer.dimension();
|
||||
- if (!crop_coordinates.empty()) {
|
||||
- x0 = crop_coordinates[0];
|
||||
- y0 = crop_coordinates[1];
|
||||
- input_dimension =
|
||||
- GetCropDimension(x0, crop_coordinates[2], y0, crop_coordinates[3]);
|
||||
- }
|
||||
- if (input_dimension == output_buffer->dimension()) {
|
||||
- // Cropping only case.
|
||||
- return CropPlane(buffer, x0, y0, crop_coordinates[2], crop_coordinates[3],
|
||||
- output_buffer);
|
||||
- }
|
||||
-
|
||||
- // Cropping is achieved by adjusting origin to (x0, y0).
|
||||
- ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format()));
|
||||
- int adjusted_offset =
|
||||
- buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride;
|
||||
- FrameBuffer::Plane plane = {
|
||||
- /*buffer=*/buffer.plane(0).buffer + adjusted_offset,
|
||||
- /*stride=*/{buffer.plane(0).stride.row_stride_bytes, pixel_stride}};
|
||||
- auto adjusted_buffer =
|
||||
- FrameBuffer::Create({plane}, input_dimension, buffer.format(),
|
||||
- buffer.orientation(), buffer.timestamp());
|
||||
-
|
||||
- // Uniform resize is achieved by adjusting the resize dimension to fit the
|
||||
- // output_buffer and respect the input aspect ratio at the same time. We
|
||||
- // create an intermediate output buffer with adjusted dimension and point its
|
||||
- // backing buffer to the output_buffer. Note the stride information on the
|
||||
- // adjusted_output_buffer is not used in the Resize* methods.
|
||||
- FrameBuffer::Dimension adjusted_dimension =
|
||||
- GetScaledDimension(input_dimension, output_buffer->dimension());
|
||||
- FrameBuffer::Plane output_plane = {/*buffer=*/output_buffer->plane(0).buffer,
|
||||
- /*stride=*/output_buffer->plane(0).stride};
|
||||
- auto adjusted_output_buffer = FrameBuffer::Create(
|
||||
- {output_plane}, adjusted_dimension, output_buffer->format(),
|
||||
- output_buffer->orientation(), output_buffer->timestamp());
|
||||
-
|
||||
- switch (buffer.format()) {
|
||||
- case FrameBuffer::Format::kRGB:
|
||||
- return ResizeRgb(*adjusted_buffer, adjusted_output_buffer.get());
|
||||
- case FrameBuffer::Format::kRGBA:
|
||||
- return ResizeRgba(*adjusted_buffer, adjusted_output_buffer.get());
|
||||
- case FrameBuffer::Format::kGRAY:
|
||||
- return ResizeGray(*adjusted_buffer, adjusted_output_buffer.get());
|
||||
- default:
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kInternal,
|
||||
- absl::StrFormat("Format %i is not supported.", buffer.format()),
|
||||
- TfLiteSupportStatus::kImageProcessingError);
|
||||
- }
|
||||
-}
|
||||
-
|
||||
-absl::Status UniformCropResizeYuv(const FrameBuffer& buffer,
|
||||
- std::vector<int> crop_coordinates,
|
||||
- FrameBuffer* output_buffer) {
|
||||
- int x0 = 0, y0 = 0;
|
||||
- FrameBuffer::Dimension input_dimension = buffer.dimension();
|
||||
- if (!crop_coordinates.empty()) {
|
||||
- x0 = crop_coordinates[0];
|
||||
- y0 = crop_coordinates[1];
|
||||
- input_dimension =
|
||||
- GetCropDimension(x0, crop_coordinates[2], y0, crop_coordinates[3]);
|
||||
- }
|
||||
- if (input_dimension == output_buffer->dimension()) {
|
||||
- // Cropping only case.
|
||||
- int x1 = crop_coordinates[2];
|
||||
- int y1 = crop_coordinates[3];
|
||||
- switch (buffer.format()) {
|
||||
- case FrameBuffer::Format::kNV12:
|
||||
- case FrameBuffer::Format::kNV21:
|
||||
- return CropNv(buffer, x0, y0, x1, y1, output_buffer);
|
||||
- case FrameBuffer::Format::kYV12:
|
||||
- case FrameBuffer::Format::kYV21:
|
||||
- return CropYv(buffer, x0, y0, x1, y1, output_buffer);
|
||||
- default:
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kInternal,
|
||||
- absl::StrFormat("Format %i is not supported.", buffer.format()),
|
||||
- TfLiteSupportStatus::kImageProcessingError);
|
||||
- }
|
||||
- }
|
||||
-
|
||||
- // Cropping is achieved by adjusting origin to (x0, y0).
|
||||
- ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
|
||||
- FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
|
||||
- // Cropping YUV planes by offsetting the origins of each plane.
|
||||
- // TODO(b/152629712): Investigate the impact of color shifting caused by the
|
||||
- // bounding box with odd X or Y starting positions.
|
||||
- const int plane_y_offset = input_data.y_row_stride * y0 + x0;
|
||||
- const int plane_uv_offset = input_data.uv_row_stride * (y0 / 2) +
|
||||
- input_data.uv_pixel_stride * (x0 / 2);
|
||||
- FrameBuffer::Plane adjusted_plane_y = {
|
||||
- /*buffer=*/input_data.y_buffer + plane_y_offset,
|
||||
- /*stride=*/{input_data.y_row_stride, /*pixel_stride_bytes=*/1}};
|
||||
- FrameBuffer::Plane adjusted_plane_u = {
|
||||
- /*buffer=*/input_data.u_buffer + plane_uv_offset,
|
||||
- /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}};
|
||||
- FrameBuffer::Plane adjusted_plane_v = {
|
||||
- /*buffer=*/input_data.v_buffer + plane_uv_offset,
|
||||
- /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}};
|
||||
-
|
||||
- // Uniform resize is achieved by adjusting the resize dimension to fit the
|
||||
- // output_buffer and respect the input aspect ratio at the same time. For
|
||||
- // YUV formats, we need access to the actual output dimension to get the
|
||||
- // correct address of each plane. For this, we are not calling ResizeNv or
|
||||
- // ResizeYv but the libyuv scale methods directly.
|
||||
- FrameBuffer::Dimension adjusted_dimension =
|
||||
- GetScaledDimension(input_dimension, output_buffer->dimension());
|
||||
- ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
|
||||
- FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
|
||||
-
|
||||
- switch (buffer.format()) {
|
||||
- case FrameBuffer::Format::kNV12: {
|
||||
- int ret = libyuv::NV12Scale(
|
||||
- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes,
|
||||
- adjusted_plane_u.buffer, adjusted_plane_u.stride.row_stride_bytes,
|
||||
- input_dimension.width, input_dimension.height,
|
||||
- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride,
|
||||
- const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride,
|
||||
- adjusted_dimension.width, adjusted_dimension.height,
|
||||
- libyuv::FilterMode::kFilterBilinear);
|
||||
- if (ret != 0) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Libyuv NV12Scale operation failed.",
|
||||
- TfLiteSupportStatus::kImageProcessingBackendError);
|
||||
- }
|
||||
- return absl::OkStatus();
|
||||
- }
|
||||
- case FrameBuffer::Format::kNV21: {
|
||||
- int ret = libyuv::NV12Scale(
|
||||
- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes,
|
||||
- adjusted_plane_v.buffer, adjusted_plane_v.stride.row_stride_bytes,
|
||||
- input_dimension.width, input_dimension.height,
|
||||
- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride,
|
||||
- const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride,
|
||||
- adjusted_dimension.width, adjusted_dimension.height,
|
||||
- libyuv::FilterMode::kFilterBilinear);
|
||||
- if (ret != 0) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Libyuv NV12Scale operation failed.",
|
||||
- TfLiteSupportStatus::kImageProcessingBackendError);
|
||||
- }
|
||||
- return absl::OkStatus();
|
||||
- }
|
||||
- case FrameBuffer::Format::kYV12:
|
||||
- case FrameBuffer::Format::kYV21: {
|
||||
- int ret = libyuv::I420Scale(
|
||||
- adjusted_plane_y.buffer, adjusted_plane_y.stride.row_stride_bytes,
|
||||
- adjusted_plane_u.buffer, adjusted_plane_u.stride.row_stride_bytes,
|
||||
- adjusted_plane_v.buffer, adjusted_plane_v.stride.row_stride_bytes,
|
||||
- input_dimension.width, input_dimension.height,
|
||||
- const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride,
|
||||
- const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride,
|
||||
- const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride,
|
||||
- adjusted_dimension.width, adjusted_dimension.height,
|
||||
- libyuv::FilterMode::kFilterBilinear);
|
||||
- if (ret != 0) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Libyuv I420Scale operation failed.",
|
||||
- TfLiteSupportStatus::kImageProcessingBackendError);
|
||||
- }
|
||||
- return absl::OkStatus();
|
||||
- }
|
||||
- default:
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kInternal,
|
||||
- absl::StrFormat("Format %i is not supported.", buffer.format()),
|
||||
- TfLiteSupportStatus::kImageProcessingError);
|
||||
- }
|
||||
- return absl::OkStatus();
|
||||
-}
|
||||
-} // namespace
|
||||
+} // namespace
|
||||
|
||||
absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0,
|
||||
int y0, int x1, int y1,
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
26
third_party/tflite_support/patches/0004-rm-noop-deprecated-attribute.patch
vendored
Normal file
26
third_party/tflite_support/patches/0004-rm-noop-deprecated-attribute.patch
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
From 243aadd7dcea9be980aa89d183bf2dea7cba202b Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Wed, 15 Dec 2021 15:49:40 -0800
|
||||
Subject: [PATCH 04/11] rm noop deprecated attribute
|
||||
|
||||
---
|
||||
.../cc/task/text/nlclassifier/nl_classifier.h | 3 ---
|
||||
1 file changed, 3 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
index d5b49dfd75277..ac12536355db4 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
@@ -43,9 +43,6 @@ namespace text {
|
||||
namespace nlclassifier {
|
||||
|
||||
// Options to identify input and output tensors of the model
|
||||
-ABSL_DEPRECATED(
|
||||
- "Prefer using `tflite::task::text::NLClassifierOptions` and "
|
||||
- "`CreateFromOptions`")
|
||||
struct NLClassifierOptions {
|
||||
int input_tensor_index = 0;
|
||||
int output_score_tensor_index = 0;
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
25
third_party/tflite_support/patches/0005-use-size_t-in-for-loop.patch
vendored
Normal file
25
third_party/tflite_support/patches/0005-use-size_t-in-for-loop.patch
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
From 5d67933d8d4440816a02f7a319d7323041c3f7bf Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Wed, 15 Dec 2021 15:51:22 -0800
|
||||
Subject: [PATCH 05/11] use size_t in for loop
|
||||
|
||||
---
|
||||
.../cc/task/text/nlclassifier/nl_classifier.h | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
index ac12536355db4..2adafba8f2fa9 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
|
||||
@@ -179,7 +179,7 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
|
||||
metadata_array,
|
||||
const std::string& name, int index) {
|
||||
if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
|
||||
- for (int i = 0; i < metadata_array->size(); i++) {
|
||||
+ for (size_t i = 0; i < metadata_array->size(); i++) {
|
||||
if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
|
||||
return tensors[i];
|
||||
}
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
29
third_party/tflite_support/patches/0006-unused-variable.patch
vendored
Normal file
29
third_party/tflite_support/patches/0006-unused-variable.patch
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
From b2d06daf8ab5cff8748489407b6ad10ea600948d Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Thu, 16 Dec 2021 08:35:07 -0800
|
||||
Subject: [PATCH 06/11] unused variable
|
||||
|
||||
---
|
||||
.../src/tensorflow_lite_support/cc/task/core/task_utils.h | 6 ++++--
|
||||
1 file changed, 4 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
index 03ef9ade9af41..e95ea73a4a812 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h
|
||||
@@ -144,8 +144,10 @@ inline absl::Status PopulateVector(const TfLiteTensor* tensor,
|
||||
template <>
|
||||
inline absl::Status PopulateVector<std::string>(
|
||||
const TfLiteTensor* tensor, std::vector<std::string>* data) {
|
||||
- std::string* v;
|
||||
- ASSIGN_OR_RETURN(v, AssertAndReturnTypedTensor<std::string>(tensor));
|
||||
+ if (tensor->type != typeToTfLiteType<std::string>()) {
|
||||
+ return absl::InvalidArgumentError("not of type string");
|
||||
+ }
|
||||
+
|
||||
int num = GetStringCount(tensor);
|
||||
data->reserve(num);
|
||||
for (int i = 0; i < num; i++) {
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
@ -1,25 +1,25 @@
|
||||
From 670dfffa386fd0ff28e66cfe1238af43b4e587ce Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Rubery <drubery@chromium.org>
|
||||
Date: Thu, 6 May 2021 11:22:13 -0700
|
||||
Subject: [PATCH] Remove use of banned absl::any
|
||||
From d3d4385132632282fc91c735875ebfc90697b067 Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Thu, 16 Dec 2021 13:28:16 -0800
|
||||
Subject: [PATCH 07/11] do not use absl any
|
||||
|
||||
---
|
||||
.../cc/task/vision/core/frame_buffer.h | 27 -------------------
|
||||
1 file changed, 27 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
index 2bea92883c4d..1556b7dfabef 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
index c1289673cb82b..1668447393e9e 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
|
||||
@@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
-#include "absl/types/any.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/strings/str_cat.h" // from @com_google_absl
|
||||
#include "absl/time/clock.h" // from @com_google_absl
|
||||
#include "absl/time/time.h" // from @com_google_absl
|
||||
-#include "absl/types/any.h" // from @com_google_absl
|
||||
#include "absl/types/optional.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/port/integral_types.h"
|
||||
#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
@@ -253,31 +252,6 @@ class FrameBuffer {
|
||||
@@ -250,31 +249,6 @@ class FrameBuffer {
|
||||
return {};
|
||||
}
|
||||
|
||||
@ -51,7 +51,7 @@ index 2bea92883c4d..1556b7dfabef 100644
|
||||
// Returns FrameBuffer dimension.
|
||||
const Dimension dimension() const { return dimension_; }
|
||||
|
||||
@@ -292,7 +266,6 @@ class FrameBuffer {
|
||||
@@ -289,7 +263,6 @@ class FrameBuffer {
|
||||
|
||||
private:
|
||||
std::vector<Plane> planes_;
|
||||
@ -60,5 +60,5 @@ index 2bea92883c4d..1556b7dfabef 100644
|
||||
Format format_;
|
||||
Orientation orientation_;
|
||||
--
|
||||
2.31.1.607.g51e8a6a459-goog
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
25
third_party/tflite_support/patches/0008-unused-string-include.patch
vendored
Normal file
25
third_party/tflite_support/patches/0008-unused-string-include.patch
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
From 5feffc2cdd8c970490fadd812401be4eb57174d5 Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Thu, 16 Dec 2021 13:43:57 -0800
|
||||
Subject: [PATCH 08/11] unused string include
|
||||
|
||||
---
|
||||
.../cc/text/tokenizers/tokenizer_jni_lib.h | 2 --
|
||||
1 file changed, 2 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
|
||||
index fc7285c6807b0..33677d305a853 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
|
||||
@@ -17,8 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
-#include <string>
|
||||
-
|
||||
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
|
||||
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
|
||||
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
666
third_party/tflite_support/patches/0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch
vendored
Normal file
666
third_party/tflite_support/patches/0009-remove-unbuilt-files-and-change-exec-bit-where-neede.patch
vendored
Normal file
@ -0,0 +1,666 @@
|
||||
From 515f1ef8496e5c73318aa41f6295bbfbefb6bbae Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Thu, 16 Dec 2021 13:56:14 -0800
|
||||
Subject: [PATCH 09/11] remove unbuilt files and change exec bit where needed
|
||||
|
||||
---
|
||||
.../cc/port/benchmark.h | 21 ---
|
||||
.../cc/port/default/status_matchers.h | 55 -------
|
||||
.../tensorflow_lite_support/cc/port/gmock.h | 21 ---
|
||||
.../tensorflow_lite_support/cc/port/gtest.h | 21 ---
|
||||
.../tensorflow_lite_support/cc/port/proto2.h | 32 ----
|
||||
.../examples/task/audio/desktop/python/BUILD | 0
|
||||
.../task/audio/desktop/python/README.md | 0
|
||||
.../desktop/python/audio_classifier_demo.py | 0
|
||||
.../examples/task/vision/desktop/python/BUILD | 0
|
||||
.../desktop/python/image_classifier_demo.py | 0
|
||||
.../desktop/python/image_segmenter_demo.py | 0
|
||||
.../desktop/python/object_detector_demo.py | 0
|
||||
.../ios/utils/Sources/TFLStringUtil.mm | 23 ---
|
||||
.../metadata/cc/metadata_populator.cc | 150 ------------------
|
||||
.../metadata/cc/utils/zip_mem_file.cc | 124 ---------------
|
||||
.../metadata/cc/utils/zip_mem_file.h | 71 ---------
|
||||
.../odml/ios/image/resources/grace_hopper.jpg | Bin
|
||||
.../tools/ci_build/build_all.sh | 0
|
||||
.../ci_build/builds/build_ios_framework.sh | 0
|
||||
.../tools/ci_build/builds/pip_smoke_test.sh | 0
|
||||
.../tools/ci_build/common.sh | 0
|
||||
.../tools/ci_build/common_win.bat | 0
|
||||
22 files changed, 518 deletions(-)
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc
|
||||
delete mode 100644 third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h
|
||||
mode change 100755 => 100644 third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg
|
||||
mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh
|
||||
mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh
|
||||
mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh
|
||||
mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh
|
||||
mode change 100644 => 100755 third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat
|
||||
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h
|
||||
deleted file mode 100644
|
||||
index 74bc1a6857664..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/benchmark.h
|
||||
+++ /dev/null
|
||||
@@ -1,21 +0,0 @@
|
||||
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
|
||||
-
|
||||
-#include "gtest/benchmark.h"
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h
|
||||
deleted file mode 100644
|
||||
index 6d9668043c183..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/status_matchers.h
|
||||
+++ /dev/null
|
||||
@@ -1,55 +0,0 @@
|
||||
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
|
||||
-
|
||||
-#include "gmock/gmock.h"
|
||||
-#include "gtest/gtest.h"
|
||||
-
|
||||
-#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
|
||||
-#define SUPPORT_STATUS_MACROS_IMPL_CONCAT_(x, y) \
|
||||
- SUPPORT_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
|
||||
-
|
||||
-#undef SUPPORT_ASSERT_OK
|
||||
-#define SUPPORT_ASSERT_OK(expr) \
|
||||
- SUPPORT_ASSERT_OK_IMPL_( \
|
||||
- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr)
|
||||
-
|
||||
-#define SUPPORT_ASSERT_OK_IMPL_(status, expr) \
|
||||
- auto status = (expr); \
|
||||
- ASSERT_TRUE(status.ok());
|
||||
-
|
||||
-#undef SUPPORT_EXPECT_OK
|
||||
-#define SUPPORT_EXPECT_OK(expr) \
|
||||
- SUPPORT_EXPECT_OK_IMPL_( \
|
||||
- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), expr)
|
||||
-
|
||||
-#define SUPPORT_EXPECT_OK_IMPL_(status, expr) \
|
||||
- auto status = (expr); \
|
||||
- EXPECT_TRUE(status.ok());
|
||||
-
|
||||
-#undef SUPPORT_ASSERT_OK_AND_ASSIGN
|
||||
-#define SUPPORT_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
|
||||
- SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_( \
|
||||
- SUPPORT_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \
|
||||
- rexpr)
|
||||
-
|
||||
-#define SUPPORT_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \
|
||||
- auto statusor = (rexpr); \
|
||||
- ASSERT_TRUE(statusor.ok()); \
|
||||
- lhs = std::move(statusor.value())
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MATCHERS_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h
|
||||
deleted file mode 100644
|
||||
index 5e4334db323d6..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gmock.h
|
||||
+++ /dev/null
|
||||
@@ -1,21 +0,0 @@
|
||||
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
|
||||
-
|
||||
-#include "gmock/gmock.h"
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h
|
||||
deleted file mode 100644
|
||||
index dbe2e5e6f9d7c..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/gtest.h
|
||||
+++ /dev/null
|
||||
@@ -1,21 +0,0 @@
|
||||
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
|
||||
-
|
||||
-#include "gtest/gtest.h"
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h
|
||||
deleted file mode 100644
|
||||
index 3cde2ab81d6ee..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/proto2.h
|
||||
+++ /dev/null
|
||||
@@ -1,32 +0,0 @@
|
||||
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_
|
||||
-
|
||||
-#include "google/protobuf/message_lite.h"
|
||||
-#include "google/protobuf/text_format.h"
|
||||
-
|
||||
-namespace tflite {
|
||||
-namespace support {
|
||||
-namespace proto {
|
||||
-
|
||||
-using TextFormat = ::google::protobuf::TextFormat;
|
||||
-using MessageLite = ::google::protobuf::MessageLite;
|
||||
-
|
||||
-} // namespace proto
|
||||
-} // namespace support
|
||||
-} // namespace tflite
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_PROTO_NS_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/BUILD
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/README.md
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/audio/desktop/python/audio_classifier_demo.py
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/BUILD
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_classifier_demo.py
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/image_segmenter_demo.py
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py b/third_party/tflite_support/src/tensorflow_lite_support/examples/task/vision/desktop/python/object_detector_demo.py
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
deleted file mode 100644
|
||||
index 6e9cf23802427..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
|
||||
+++ /dev/null
|
||||
@@ -1,23 +0,0 @@
|
||||
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
|
||||
-
|
||||
-std::string MakeString(NSString* str) { return std::string([str UTF8String]); }
|
||||
-
|
||||
-NSString* MakeNSString(const std::string& str) {
|
||||
- return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
|
||||
- length:str.length()
|
||||
- encoding:NSUTF8StringEncoding];
|
||||
-}
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
|
||||
deleted file mode 100644
|
||||
index e21d426369e2e..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/metadata_populator.cc
|
||||
+++ /dev/null
|
||||
@@ -1,150 +0,0 @@
|
||||
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#include "tensorflow_lite_support/metadata/cc/metadata_populator.h"
|
||||
-
|
||||
-#include <cstdlib>
|
||||
-#include <cstring>
|
||||
-#include <functional>
|
||||
-
|
||||
-#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
-#include "contrib/minizip/ioapi.h"
|
||||
-#include "contrib/minizip/zip.h"
|
||||
-#include "tensorflow/lite/schema/schema_generated.h"
|
||||
-#include "tensorflow_lite_support/cc/common.h"
|
||||
-#include "tensorflow_lite_support/cc/port/status_macros.h"
|
||||
-#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
-#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h"
|
||||
-#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
|
||||
-
|
||||
-namespace tflite {
|
||||
-namespace metadata {
|
||||
-
|
||||
-namespace {
|
||||
-constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
|
||||
-
|
||||
-using ::absl::StatusCode;
|
||||
-using ::tflite::support::CreateStatusWithPayload;
|
||||
-using ::tflite::support::TfLiteSupportStatus;
|
||||
-
|
||||
-} // namespace
|
||||
-
|
||||
-ModelMetadataPopulator::ModelMetadataPopulator(const tflite::Model& model) {
|
||||
- model.UnPackTo(&model_t_);
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-tflite::support::StatusOr<std::unique_ptr<ModelMetadataPopulator>>
|
||||
-ModelMetadataPopulator::CreateFromModelBuffer(const char* buffer_data,
|
||||
- size_t buffer_size) {
|
||||
- // Rely on the simplest, base flatbuffers verifier. Here is not the place to
|
||||
- // e.g. use an OpResolver: we just want to make sure the buffer is valid to
|
||||
- // access the metadata.
|
||||
- flatbuffers::Verifier verifier = flatbuffers::Verifier(
|
||||
- reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
|
||||
- if (!tflite::VerifyModelBuffer(verifier)) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kInvalidArgument,
|
||||
- "The model is not a valid FlatBuffer buffer.",
|
||||
- TfLiteSupportStatus::kInvalidFlatBufferError);
|
||||
- }
|
||||
- // Use absl::WrapUnique() to call private constructor:
|
||||
- // https://abseil.io/tips/126.
|
||||
- return absl::WrapUnique(
|
||||
- new ModelMetadataPopulator(*tflite::GetModel(buffer_data)));
|
||||
-}
|
||||
-
|
||||
-void ModelMetadataPopulator::LoadMetadata(const char* metadata_buffer_data,
|
||||
- size_t metadata_buffer_size) {
|
||||
- // Pack the model metadata in a buffer.
|
||||
- auto model_metadata_buffer = std::make_unique<tflite::BufferT>();
|
||||
- model_metadata_buffer->data = {metadata_buffer_data,
|
||||
- metadata_buffer_data + metadata_buffer_size};
|
||||
- // Check if the model already has metadata. If so, just override the buffer
|
||||
- // and exit.
|
||||
- for (const auto& metadata_t : model_t_.metadata) {
|
||||
- if (metadata_t->name == kMetadataBufferName) {
|
||||
- model_t_.buffers[metadata_t->buffer] = std::move(model_metadata_buffer);
|
||||
- return;
|
||||
- }
|
||||
- }
|
||||
- // Model doesn't already have metadata: add metadata buffer and pointer to the
|
||||
- // buffer in the model metadata section.
|
||||
- model_t_.buffers.push_back(std::move(model_metadata_buffer));
|
||||
- auto metadata_t = std::make_unique<tflite::MetadataT>();
|
||||
- metadata_t->name = kMetadataBufferName;
|
||||
- metadata_t->buffer = model_t_.buffers.size() - 1;
|
||||
- model_t_.metadata.push_back(std::move(metadata_t));
|
||||
-}
|
||||
-
|
||||
-void ModelMetadataPopulator::LoadAssociatedFiles(
|
||||
- const absl::flat_hash_map<std::string, std::string>& associated_files) {
|
||||
- associated_files_ = associated_files;
|
||||
-}
|
||||
-
|
||||
-tflite::support::StatusOr<std::string>
|
||||
-ModelMetadataPopulator::AppendAssociatedFiles(const char* model_buffer_data,
|
||||
- size_t model_buffer_size) {
|
||||
- // Create in-memory zip file.
|
||||
- ZipMemFile mem_file = ZipMemFile(model_buffer_data, model_buffer_size);
|
||||
- // Open zip.
|
||||
- zipFile zf = zipOpen2(/*pathname=*/nullptr, APPEND_STATUS_CREATEAFTER,
|
||||
- /*globalcomment=*/nullptr, &mem_file.GetFileFuncDef());
|
||||
- if (zf == nullptr) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Unable to open zip archive",
|
||||
- TfLiteSupportStatus::kMetadataAssociatedFileZipError);
|
||||
- }
|
||||
- // Write associated files.
|
||||
- for (const auto& [name, contents] : associated_files_) {
|
||||
- if ((zipOpenNewFileInZip(zf, name.c_str(),
|
||||
- /*zipfi=*/nullptr,
|
||||
- /*extrafield_local=*/nullptr,
|
||||
- /*size_extrafield_local=*/0,
|
||||
- /*extrafield_global=*/nullptr,
|
||||
- /*size_extrafield_global=*/0,
|
||||
- /*comment=*/nullptr,
|
||||
- /*method=*/0,
|
||||
- /*level=*/Z_DEFAULT_COMPRESSION) != ZIP_OK) ||
|
||||
- (zipWriteInFileInZip(zf, contents.data(), contents.length()) !=
|
||||
- ZIP_OK) ||
|
||||
- (zipCloseFileInZip(zf) != ZIP_OK)) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Unable to write file to zip archive",
|
||||
- TfLiteSupportStatus::kMetadataAssociatedFileZipError);
|
||||
- }
|
||||
- }
|
||||
- // Close zip.
|
||||
- if (zipClose(zf, /*global_comment=*/nullptr) != ZIP_OK) {
|
||||
- return CreateStatusWithPayload(
|
||||
- StatusCode::kUnknown, "Unable to close zip archive",
|
||||
- TfLiteSupportStatus::kMetadataAssociatedFileZipError);
|
||||
- }
|
||||
- // Return as a string.
|
||||
- return std::string(mem_file.GetFileContent());
|
||||
-}
|
||||
-
|
||||
-tflite::support::StatusOr<std::string> ModelMetadataPopulator::Populate() {
|
||||
- // Build model.
|
||||
- flatbuffers::FlatBufferBuilder model_fbb;
|
||||
- model_fbb.Finish(tflite::Model::Pack(model_fbb, &model_t_),
|
||||
- tflite::ModelIdentifier());
|
||||
- return AppendAssociatedFiles(
|
||||
- reinterpret_cast<char*>(model_fbb.GetBufferPointer()),
|
||||
- model_fbb.GetSize());
|
||||
-}
|
||||
-
|
||||
-} // namespace metadata
|
||||
-} // namespace tflite
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc
|
||||
deleted file mode 100644
|
||||
index 2e4d9107c8c31..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.cc
|
||||
+++ /dev/null
|
||||
@@ -1,124 +0,0 @@
|
||||
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#include "tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h"
|
||||
-
|
||||
-#include <algorithm>
|
||||
-#include <cstdio>
|
||||
-
|
||||
-#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
-#include "contrib/minizip/ioapi.h"
|
||||
-
|
||||
-namespace tflite {
|
||||
-namespace metadata {
|
||||
-
|
||||
-ZipMemFile::ZipMemFile(const char* buffer, size_t size)
|
||||
- : data_(buffer, size), offset_(0) {
|
||||
- zlib_filefunc_def_.zopen_file = OpenFile;
|
||||
- zlib_filefunc_def_.zread_file = ReadFile;
|
||||
- zlib_filefunc_def_.zwrite_file = WriteFile;
|
||||
- zlib_filefunc_def_.ztell_file = TellFile;
|
||||
- zlib_filefunc_def_.zseek_file = SeekFile;
|
||||
- zlib_filefunc_def_.zclose_file = CloseFile;
|
||||
- zlib_filefunc_def_.zerror_file = ErrorFile;
|
||||
- zlib_filefunc_def_.opaque = this;
|
||||
-}
|
||||
-
|
||||
-zlib_filefunc_def& ZipMemFile::GetFileFuncDef() { return zlib_filefunc_def_; }
|
||||
-
|
||||
-absl::string_view ZipMemFile::GetFileContent() const { return data_; }
|
||||
-
|
||||
-/* static */
|
||||
-voidpf ZipMemFile::OpenFile(voidpf opaque, const char* filename, int mode) {
|
||||
- // Result is never used, but needs to be non-null for `zipOpen2` not to fail.
|
||||
- return opaque;
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-size_t ZipMemFile::ReadFile(voidpf opaque, voidpf stream, void* buf,
|
||||
- size_t size) {
|
||||
- auto* mem_file = static_cast<ZipMemFile*>(opaque);
|
||||
- if (mem_file->offset_ < 0 || mem_file->Size() < mem_file->offset_) {
|
||||
- return 0;
|
||||
- }
|
||||
- if (mem_file->offset_ + size > mem_file->Size()) {
|
||||
- size = mem_file->Size() - mem_file->offset_;
|
||||
- }
|
||||
- memcpy(buf,
|
||||
- static_cast<const char*>(mem_file->data_.c_str()) + mem_file->offset_,
|
||||
- size);
|
||||
- mem_file->offset_ += size;
|
||||
- return size;
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-size_t ZipMemFile::WriteFile(voidpf opaque, voidpf stream, const void* buf,
|
||||
- size_t size) {
|
||||
- auto* mem_file = static_cast<ZipMemFile*>(opaque);
|
||||
- if (mem_file->offset_ + size > mem_file->Size()) {
|
||||
- mem_file->data_.resize(mem_file->offset_ + size);
|
||||
- }
|
||||
- mem_file->data_.replace(mem_file->offset_, size,
|
||||
- static_cast<const char*>(buf), size);
|
||||
- mem_file->offset_ += size;
|
||||
- return size;
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-ptrdiff_t ZipMemFile::TellFile(voidpf opaque, voidpf stream) {
|
||||
- return static_cast<ZipMemFile*>(opaque)->offset_;
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-ptrdiff_t ZipMemFile::SeekFile(voidpf opaque, voidpf stream, size_t offset,
|
||||
- int origin) {
|
||||
- auto* mem_file = static_cast<ZipMemFile*>(opaque);
|
||||
- switch (origin) {
|
||||
- case SEEK_SET:
|
||||
- mem_file->offset_ = offset;
|
||||
- return 0;
|
||||
- case SEEK_CUR:
|
||||
- if (mem_file->offset_ + offset < 0 ||
|
||||
- mem_file->offset_ + offset > mem_file->Size()) {
|
||||
- return -1;
|
||||
- }
|
||||
- mem_file->offset_ += offset;
|
||||
- return 0;
|
||||
- case SEEK_END:
|
||||
- if (mem_file->Size() - offset < 0 ||
|
||||
- mem_file->Size() - offset > mem_file->Size()) {
|
||||
- return -1;
|
||||
- }
|
||||
- mem_file->offset_ = offset + mem_file->Size();
|
||||
- return 0;
|
||||
- default:
|
||||
- return -1;
|
||||
- }
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-int ZipMemFile::CloseFile(voidpf opaque, voidpf stream) {
|
||||
- // Nothing to do.
|
||||
- return 0;
|
||||
-}
|
||||
-
|
||||
-/* static */
|
||||
-int ZipMemFile::ErrorFile(voidpf opaque, voidpf stream) {
|
||||
- // Unused.
|
||||
- return 0;
|
||||
-}
|
||||
-
|
||||
-} // namespace metadata
|
||||
-} // namespace tflite
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h b/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h
|
||||
deleted file mode 100644
|
||||
index ef7843d70cff6..0000000000000
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/metadata/cc/utils/zip_mem_file.h
|
||||
+++ /dev/null
|
||||
@@ -1,71 +0,0 @@
|
||||
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
-
|
||||
-Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-you may not use this file except in compliance with the License.
|
||||
-You may obtain a copy of the License at
|
||||
-
|
||||
- http://www.apache.org/licenses/LICENSE-2.0
|
||||
-
|
||||
-Unless required by applicable law or agreed to in writing, software
|
||||
-distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-See the License for the specific language governing permissions and
|
||||
-limitations under the License.
|
||||
-==============================================================================*/
|
||||
-
|
||||
-#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_
|
||||
-#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_
|
||||
-
|
||||
-#include <cstdlib>
|
||||
-
|
||||
-#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
-#include "contrib/minizip/ioapi.h"
|
||||
-
|
||||
-namespace tflite {
|
||||
-namespace metadata {
|
||||
-
|
||||
-// In-memory zip file implementation.
|
||||
-//
|
||||
-// Adapted from [1], with a few key differences:
|
||||
-// * backed by an `std::string` instead of malloc-ed C buffers,
|
||||
-// * supports opening the file for writing through `zipOpen2`.
|
||||
-//
|
||||
-// [1]:
|
||||
-// https://github.com/google/libkml/blob/master/third_party/zlib-1.2.3/contrib/minizip/iomem_simple.c
|
||||
-class ZipMemFile {
|
||||
- public:
|
||||
- // Constructs an in-memory zip file from a buffer.
|
||||
- ZipMemFile(const char* buffer, size_t size);
|
||||
- // Provides access to the `zlib_filefunc_def` implementation for the in-memory
|
||||
- // zip file.
|
||||
- zlib_filefunc_def& GetFileFuncDef();
|
||||
- // Provides access to the file contents.
|
||||
- absl::string_view GetFileContent() const;
|
||||
-
|
||||
- private:
|
||||
- // The string backing the in-memory file.
|
||||
- std::string data_;
|
||||
- // The current offset in the file.
|
||||
- size_t offset_;
|
||||
- // The `zlib_filefunc_def` implementation for this in-memory zip file.
|
||||
- zlib_filefunc_def zlib_filefunc_def_;
|
||||
-
|
||||
- // Convenience function to access the current data size.
|
||||
- size_t Size() const { return data_.size(); }
|
||||
-
|
||||
- // The file function implementations used in the `zlib_filefunc_def`.
|
||||
- static voidpf OpenFile(voidpf opaque, const char* filename, int mode);
|
||||
- static size_t ReadFile(voidpf opaque, voidpf stream, void* buf, size_t size);
|
||||
- static size_t WriteFile(voidpf opaque, voidpf stream, const void* buf,
|
||||
- size_t size);
|
||||
- static ptrdiff_t TellFile(voidpf opaque, voidpf stream);
|
||||
- static ptrdiff_t SeekFile(voidpf opaque, voidpf stream, size_t offset,
|
||||
- int origin);
|
||||
- static int CloseFile(voidpf opaque, voidpf stream);
|
||||
- static int ErrorFile(voidpf opaque, voidpf stream);
|
||||
-};
|
||||
-
|
||||
-} // namespace metadata
|
||||
-} // namespace tflite
|
||||
-
|
||||
-#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_UTILS_ZIP_MEM_FILE_H_
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg b/third_party/tflite_support/src/tensorflow_lite_support/odml/ios/image/resources/grace_hopper.jpg
|
||||
old mode 100755
|
||||
new mode 100644
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/build_all.sh
|
||||
old mode 100644
|
||||
new mode 100755
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/build_ios_framework.sh
|
||||
old mode 100644
|
||||
new mode 100755
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh
|
||||
old mode 100644
|
||||
new mode 100755
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common.sh
|
||||
old mode 100644
|
||||
new mode 100755
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat b/third_party/tflite_support/src/tensorflow_lite_support/tools/ci_build/common_win.bat
|
||||
old mode 100644
|
||||
new mode 100755
|
||||
--
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
@ -1,22 +1,25 @@
|
||||
From 8c5a37f7324b4a03f123c6faecedb4abc8eb0066 Mon Sep 17 00:00:00 2001
|
||||
From: mcrouse <mcrouse@google.com>
|
||||
Date: Fri, 5 Feb 2021 15:30:25 +0000
|
||||
Subject: [PATCH] remove unsupported memory map from file handler
|
||||
From 4b7e971e2f2f6ef3fd394858d975b64479047872 Mon Sep 17 00:00:00 2001
|
||||
From: Robert Ogden <robertogden@chromium.org>
|
||||
Date: Mon, 20 Dec 2021 08:50:35 -0800
|
||||
Subject: [PATCH 10/11] only support model file passed in from mem
|
||||
|
||||
---
|
||||
.../cc/task/core/external_file_handler.cc | 126 +-----------------
|
||||
.../cc/task/core/external_file_handler.h | 7 -
|
||||
.../cc/task/core/external_file_handler.cc | 143 ++----------------
|
||||
.../cc/task/core/external_file_handler.h | 23 +--
|
||||
.../cc/task/core/tflite_engine.cc | 2 -
|
||||
.../cc/task/core/tflite_engine.h | 2 -
|
||||
4 files changed, 6 insertions(+), 131 deletions(-)
|
||||
4 files changed, 10 insertions(+), 160 deletions(-)
|
||||
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
index ee689e41c6e5..55b662f0926f 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
@@ -18,9 +18,6 @@ limitations under the License.
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
index dcde0c926c653..e91a54fb7d11a 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.cc
|
||||
@@ -15,45 +15,25 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
|
||||
|
||||
-#include <errno.h>
|
||||
-#include <fcntl.h>
|
||||
#include <stddef.h>
|
||||
-#include <sys/mman.h>
|
||||
-#include <unistd.h>
|
||||
@ -24,7 +27,21 @@ index ee689e41c6e5..55b662f0926f 100644
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
@@ -40,18 +37,6 @@ using ::tflite::support::CreateStatusWithPayload;
|
||||
-#include "absl/memory/memory.h" // from @com_google_absl
|
||||
+#include "absl/memory/memory.h" // from @com_google_absl
|
||||
#include "absl/strings/str_format.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
-#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
#include "tensorflow_lite_support/cc/port/status_macros.h"
|
||||
+#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
namespace core {
|
||||
-namespace {
|
||||
|
||||
using ::absl::StatusCode;
|
||||
using ::tflite::support::CreateStatusWithPayload;
|
||||
using ::tflite::support::StatusOr;
|
||||
using ::tflite::support::TfLiteSupportStatus;
|
||||
|
||||
@ -40,10 +57,12 @@ index ee689e41c6e5..55b662f0926f 100644
|
||||
- return aligned_offset;
|
||||
-}
|
||||
-
|
||||
} // namespace
|
||||
|
||||
-} // namespace
|
||||
-
|
||||
/* static */
|
||||
@@ -71,103 +56,11 @@ absl::Status ExternalFileHandler::MapExternalFile() {
|
||||
StatusOr<std::unique_ptr<ExternalFileHandler>>
|
||||
ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) {
|
||||
@@ -71,123 +51,18 @@ absl::Status ExternalFileHandler::MapExternalFile() {
|
||||
if (!external_file_.file_content().empty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
@ -144,16 +163,22 @@ index ee689e41c6e5..55b662f0926f 100644
|
||||
- TfLiteSupportStatus::kFileMmapError);
|
||||
- }
|
||||
- return absl::OkStatus();
|
||||
+ return CreateStatusWithPayload(
|
||||
+ StatusCode::kInvalidArgument,
|
||||
+ "ExternalFile must have 'file_content' set, loading from"
|
||||
+ "'file_name' is not supported.",
|
||||
+ TfLiteSupportStatus::kInvalidArgumentError);
|
||||
+
|
||||
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
|
||||
+ "ExternalFile must specify 'file_content' "
|
||||
+ "to be compatible with Chromium.",
|
||||
+ TfLiteSupportStatus::kInvalidArgumentError);
|
||||
}
|
||||
|
||||
absl::string_view ExternalFileHandler::GetFileContent() {
|
||||
@@ -180,14 +73,7 @@ absl::string_view ExternalFileHandler::GetFileContent() {
|
||||
}
|
||||
- if (!external_file_.file_content().empty()) {
|
||||
- return external_file_.file_content();
|
||||
- } else {
|
||||
- return absl::string_view(static_cast<const char*>(buffer_) +
|
||||
- buffer_offset_ - buffer_aligned_offset_,
|
||||
- buffer_size_);
|
||||
- }
|
||||
+ return external_file_.file_content();
|
||||
}
|
||||
|
||||
-ExternalFileHandler::~ExternalFileHandler() {
|
||||
@ -168,48 +193,64 @@ index ee689e41c6e5..55b662f0926f 100644
|
||||
|
||||
} // namespace core
|
||||
} // namespace task
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
index 236d90347698..ad292dcc3702 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
@@ -65,10 +65,6 @@ class ExternalFileHandler {
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
index cf0bdf0b48037..48c62813e212e 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/external_file_handler.h
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
-#include "absl/status/status.h" // from @com_google_absl
|
||||
+#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/port/integral_types.h"
|
||||
#include "tensorflow_lite_support/cc/port/statusor.h"
|
||||
@@ -64,27 +64,6 @@ class ExternalFileHandler {
|
||||
|
||||
// Reference to the input ExternalFile.
|
||||
const ExternalFile& external_file_;
|
||||
|
||||
-
|
||||
- // The file descriptor of the ExternalFile if provided by path, as it is
|
||||
- // opened and owned by this class. Set to -1 otherwise.
|
||||
- int owned_fd_{-1};
|
||||
-
|
||||
// Points to the memory buffer mapped from the file descriptor of the
|
||||
// ExternalFile, if provided by path or file descriptor.
|
||||
void* buffer_{};
|
||||
@@ -82,9 +78,6 @@ class ExternalFileHandler {
|
||||
|
||||
// The aligned mapped memory buffer offset, if any.
|
||||
int64 buffer_aligned_offset_{};
|
||||
- // Points to the memory buffer mapped from the file descriptor of the
|
||||
- // ExternalFile, if provided by path or file descriptor.
|
||||
- void* buffer_{};
|
||||
-
|
||||
- // The mapped memory buffer offset, if any.
|
||||
- int64 buffer_offset_{};
|
||||
- // The size in bytes of the mapped memory buffer, if any.
|
||||
- int64 buffer_size_{};
|
||||
-
|
||||
- // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE):
|
||||
-
|
||||
- // The aligned mapped memory buffer offset, if any.
|
||||
- int64 buffer_aligned_offset_{};
|
||||
- // The aligned mapped memory buffer size in bytes taking into account the
|
||||
- // offset shift introduced by buffer_aligned_memory_offset_, if any.
|
||||
- int64 buffer_aligned_size_{};
|
||||
};
|
||||
|
||||
} // namespace core
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
index 0317d5f8ea34..6230e5c645c0 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
index 8cd4585161df7..484b9a099ecdc 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.cc
|
||||
@@ -15,8 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
|
||||
|
||||
-#include <unistd.h>
|
||||
-
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
diff --git a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
index 6a7e97dd264e..bc55f6b0fe72 100644
|
||||
--- a/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
+++ b/third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/match.h" // from @com_google_absl
|
||||
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
index 9b44c6e5c022a..53dabdc4841d7 100644
|
||||
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/tflite_engine.h
|
||||
@@ -16,8 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
|
||||
@ -218,7 +259,7 @@ index 6a7e97dd264e..bc55f6b0fe72 100644
|
||||
-
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/memory/memory.h" // from @com_google_absl
|
||||
--
|
||||
2.30.0.365.g02bc693789-goog
|
||||
2.34.1.307.g9b7440fafd-goog
|
||||
|
43843
third_party/tflite_support/patches/0011-run-clang-format.patch
vendored
Normal file
43843
third_party/tflite_support/patches/0011-run-clang-format.patch
vendored
Normal file
File diff suppressed because it is too large
Load Diff
170
third_party/tflite_support/src/.bazelrc
vendored
Normal file
170
third_party/tflite_support/src/.bazelrc
vendored
Normal file
@ -0,0 +1,170 @@
|
||||
# This file is based on tensorflow's (v2.2.0) .bazelrc found here:
|
||||
# https://github.com/tensorflow/tensorflow/blob/v2.2.0/.bazelrc
|
||||
|
||||
# Sets the default Apple platform to macOS.
|
||||
build:macos --apple_platform_type=macos
|
||||
|
||||
# Flag to enable remote config. Required starting from TF 2.2.
|
||||
common --experimental_repo_remote_exec
|
||||
|
||||
# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
|
||||
build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
||||
build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
|
||||
# Android workspace configurations. Should be replaced by an interactive configure in the future.
|
||||
build --action_env ANDROID_NDK_HOME
|
||||
build --action_env ANDROID_NDK_API_LEVEL
|
||||
build --action_env ANDROID_BUILD_TOOLS_VERSION
|
||||
build --action_env ANDROID_SDK_API_LEVEL
|
||||
build --action_env ANDROID_SDK_HOME
|
||||
|
||||
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
|
||||
# target CPU to build transitive dependencies correctly. See
|
||||
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
|
||||
|
||||
build:android --crosstool_top=//external:android/crosstool
|
||||
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||
build:android_arm --config=android
|
||||
build:android_arm --cpu=armeabi-v7a
|
||||
build:android_arm --fat_apk_cpu=armeabi-v7a
|
||||
build:android_arm64 --config=android
|
||||
build:android_arm64 --cpu=arm64-v8a
|
||||
build:android_arm64 --fat_apk_cpu=arm64-v8a
|
||||
build:android_x86 --config=android
|
||||
build:android_x86 --cpu=x86
|
||||
build:android_x86 --fat_apk_cpu=x86
|
||||
build:android_x86_64 --config=android
|
||||
build:android_x86_64 --cpu=x86_64
|
||||
build:android_x86_64 --fat_apk_cpu=x86_64
|
||||
|
||||
# iOS configs for each architecture and the fat binary builds.
|
||||
build:ios --apple_platform_type=ios
|
||||
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
|
||||
build:ios --copt=-Wno-c++11-narrowing
|
||||
build:ios_armv7 --config=ios
|
||||
build:ios_armv7 --cpu=ios_armv7
|
||||
build:ios_arm64 --config=ios
|
||||
build:ios_arm64 --cpu=ios_arm64
|
||||
build:ios_x86_64 --config=ios
|
||||
build:ios_x86_64 --cpu=ios_x86_64
|
||||
build:ios_fat --config=ios
|
||||
build:ios_fat --ios_multi_cpus=armv7,arm64,x86_64
|
||||
|
||||
# TFLite build configs for generic embedded Linux
|
||||
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
|
||||
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||
build:elinux_aarch64 --config=elinux
|
||||
build:elinux_aarch64 --cpu=aarch64
|
||||
build:elinux_aarch64 --distinct_host_configuration=true
|
||||
build:elinux_armhf --config=elinux
|
||||
build:elinux_armhf --cpu=armhf
|
||||
build:elinux_armhf --distinct_host_configuration=true
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
build:macos --host_cxxopt=-std=c++14
|
||||
build:windows --cxxopt=/std:c++14
|
||||
build:windows --host_cxxopt=/std:c++14
|
||||
|
||||
# Config to use a mostly-static build and disable modular op registration
|
||||
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
|
||||
# By default, TensorFlow will build with a dependence on
|
||||
# //tensorflow:libtensorflow_framework.so.
|
||||
build:monolithic --define framework_shared_object=false
|
||||
|
||||
# For projects which use TensorFlow as part of a Bazel build process, putting
|
||||
# nothing in a bazelrc will default to a monolithic build. The following line
|
||||
# opts in to modular op registration support by default.
|
||||
build --define framework_shared_object=true
|
||||
|
||||
# ASAN build
|
||||
build:asan --strip=never
|
||||
build:asan --copt -fsanitize=address
|
||||
build:asan --copt -DADDRESS_SANITIZER
|
||||
build:asan --copt -O1
|
||||
build:asan --copt -g
|
||||
build:asan --copt -fno-omit-frame-pointer
|
||||
build:asan --linkopt -fsanitize=address
|
||||
|
||||
# dbg config, as a shorthand for '--config=opt -c dbg'
|
||||
build:dbg --config=opt -c dbg
|
||||
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
|
||||
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
|
||||
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
|
||||
build:dbg --copt -DDEBUG_BUILD
|
||||
|
||||
build --define=use_fast_cpp_protos=true
|
||||
build --define=allow_oversize_protos=true
|
||||
|
||||
# TF historically used the `standalone` spawn strategy, which is deprecated; `local` is its replacement.
|
||||
build --spawn_strategy=local
|
||||
build -c opt
|
||||
|
||||
# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF
|
||||
# compilation options. It also addresses memory use due to
|
||||
# copy-on-write semantics of std::strings of the older ABI.
|
||||
build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0
|
||||
|
||||
# Make Bazel print out all options from rc files.
|
||||
build --announce_rc
|
||||
|
||||
# Other build flags.
|
||||
build --define=grpc_no_ares=true
|
||||
|
||||
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
|
||||
# --incompatible_remove_legacy_whole_archive flag does.
|
||||
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
|
||||
# TensorFlow to the default; however, test coverage wasn't enough to catch the
|
||||
# errors.
|
||||
# There is ongoing work on Bazel team's side to provide support for transitive
|
||||
# shared libraries. As part of migrating to transitive shared libraries, we
|
||||
# hope to provide a better mechanism for control over symbol exporting, and
|
||||
# then tackle this issue again.
|
||||
#
|
||||
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
|
||||
# archives in -whole_archive -no_whole_archive.
|
||||
build --noincompatible_remove_legacy_whole_archive
|
||||
|
||||
# These are Bazel 2.0's incompatible flags. TensorFlow needs to use Bazel 2.0.0
|
||||
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
|
||||
# https://github.com/tensorflow/community/pull/179
|
||||
build --noincompatible_prohibit_aapt1
|
||||
|
||||
# Build TF with C++ 17 features.
|
||||
build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress all warning messages.
|
||||
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||
build:verbose_logs --output_filter=
|
||||
build --config=short_logs
|
||||
|
||||
# Options to build TensorFlow 1.x or 2.x.
|
||||
build:v1 --define=tf_api_version=1
|
||||
build:v2 --define=tf_api_version=2
|
||||
build:v1 --action_env=TF2_BEHAVIOR=0
|
||||
build:v2 --action_env=TF2_BEHAVIOR=1
|
||||
build --config=v2
|
||||
test --config=v2
|
||||
|
||||
# Put user-specific options in .bazelrc.user
|
||||
try-import %workspace%/.bazelrc.user
|
1
third_party/tflite_support/src/.bazelversion
vendored
Normal file
1
third_party/tflite_support/src/.bazelversion
vendored
Normal file
@ -0,0 +1 @@
|
||||
3.7.2
|
9
third_party/tflite_support/src/README.md
vendored
9
third_party/tflite_support/src/README.md
vendored
@ -5,8 +5,8 @@ models onto mobile devices. It works cross-Platform and is supported on Java,
|
||||
C++ (WIP), and Swift (WIP). The TFLite Support project consists of the following
|
||||
major components:
|
||||
|
||||
* **TFLite Support Library**: a cross-platform library that helps to
|
||||
deploy TFLite models onto mobile devices.
|
||||
* **TFLite Support Library**: a cross-platform library that helps to deploy
|
||||
TFLite models onto mobile devices.
|
||||
* **TFLite Model Metadata**: (metadata populator and metadata extractor
|
||||
library): includes both human and machine readable information about what a
|
||||
model does and how to use the model.
|
||||
@ -55,6 +55,11 @@ Utils, you need to set up following env variables correctly:
|
||||
* `ANDROID_SDK_API_LEVEL`
|
||||
* `ANDROID_BUILD_TOOLS_VERSION`
|
||||
|
||||
## How to contribute
|
||||
|
||||
Please open a pull request and assign @xunkai55 or @lu-wang-g for a code
|
||||
review.
|
||||
|
||||
## Contact us
|
||||
|
||||
Let us know what you think about TFLite Support by creating a
|
||||
|
217
third_party/tflite_support/src/WORKSPACE
vendored
217
third_party/tflite_support/src/WORKSPACE
vendored
@ -1,9 +1,58 @@
|
||||
workspace(name = "org_tensorflow_lite_support")
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external")
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file")
|
||||
load("@//third_party/py:python_configure.bzl", "python_configure")
|
||||
|
||||
http_file(
|
||||
name = "mobilebert_float",
|
||||
sha256 = "883bf5d40f0b0ae435326bb21ed0f4c9004b22c3fd1539383fd16d68623696dd",
|
||||
urls = ["https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1?lite-format=tflite"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "mobilebert_with_metadata",
|
||||
sha256 = "e79d3c70108bbdee02da657b679349cab46dbb859a05b599c76b53d98e82f272",
|
||||
urls = ["https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?lite-format=tflite"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "30k-clean",
|
||||
sha256 = "fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336",
|
||||
urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_qa/30k-clean.model"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "mobilebert_vocab",
|
||||
sha256 = "07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3",
|
||||
urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_qa/mobilebert_vocab.txt"],
|
||||
)
|
||||
|
||||
|
||||
http_file(
|
||||
name = "albert",
|
||||
sha256 = "4a29c7063c518925960229f49dd03e8da5d6682001cf73037815dcd98afd728a",
|
||||
urls = ["https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "albert_with_metadata",
|
||||
sha256 = "8a8a91856b94b945e4a9f22f0332bbf105c3b6b878bb23abfc97eb89d3e8436a",
|
||||
urls = ["https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/metadata/1?lite-format=tflite"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "bert_nl_classifier",
|
||||
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
|
||||
urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_nl_classifier/bert_nl_classifier.tflite"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "bert_nl_classifier_no_metadata",
|
||||
sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
|
||||
urls = ["https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/bert_nl_classifier/bert_nl_classifier_no_metadata.tflite"],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "io_bazel_rules_closure",
|
||||
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
|
||||
@ -14,6 +63,14 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
# GoogleTest/GoogleMock framework. Used by most unit-tests.
|
||||
http_archive(
|
||||
name = "com_google_googletest",
|
||||
urls = ["https://github.com/google/googletest/archive/4ec4cd23f486bf70efcc5d2caa40f24368f752e3.zip"],
|
||||
strip_prefix = "googletest-4ec4cd23f486bf70efcc5d2caa40f24368f752e3",
|
||||
sha256 = "de682ea824bfffba05b4e33b67431c247397d6175962534305136aa06f92e049",
|
||||
)
|
||||
|
||||
# Apple and Swift rules.
|
||||
# https://github.com/bazelbuild/rules_apple/releases
|
||||
http_archive(
|
||||
@ -37,15 +94,21 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
# tf-nightly-20200810
|
||||
# TF on 2021-11-09.
|
||||
TENSORFLOW_COMMIT = "6a144e7763914d3f6141a7cdc9cb116cc23425f9"
|
||||
TENSORFLOW_SHA256 = "cec9a514c09d2b171ad447f3413151b25a6c3d88d048148cced1e85db81f3617"
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "fc6d7c57cd9427e695a38ad00fb6ecc3f623bac792dd44ad73a3f85b338b68be",
|
||||
strip_prefix = "tensorflow-8a4ffe2e1ae722cff5306778df0cfca8b7f503fe",
|
||||
sha256 = TENSORFLOW_SHA256,
|
||||
strip_prefix = "tensorflow-" + TENSORFLOW_COMMIT,
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/8a4ffe2e1ae722cff5306778df0cfca8b7f503fe.tar.gz",
|
||||
"https://github.com/tensorflow/tensorflow/archive/" + TENSORFLOW_COMMIT
|
||||
+ ".tar.gz",
|
||||
],
|
||||
patches = [
|
||||
# We need to rename lite/ios/BUILD.apple to lite/ios/BUILD.
|
||||
"@//third_party:tensorflow_lite_ios_build.patch",
|
||||
],
|
||||
patches = ["@//third_party:tensorflow_lite_ios_build.patch"],
|
||||
patch_args = ["-p1"],
|
||||
)
|
||||
|
||||
@ -60,11 +123,11 @@ gflags()
|
||||
third_party_http_archive(
|
||||
name = "pybind11",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
|
||||
"https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.6.0.tar.gz",
|
||||
"https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz",
|
||||
],
|
||||
sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d",
|
||||
strip_prefix = "pybind11-2.4.3",
|
||||
sha256 = "90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571",
|
||||
strip_prefix = "pybind11-2.6.0",
|
||||
build_file = "//third_party:pybind11.BUILD",
|
||||
)
|
||||
|
||||
@ -119,13 +182,14 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
# ABSL cpp library lts_2020_02_25
|
||||
# Needed for absl/status
|
||||
# ABSL cpp library lts_2021_03_24 Patch2
|
||||
# See https://github.com/abseil/abseil-cpp/releases for details.
|
||||
# Needed for absl/status and absl/status:statusor
|
||||
http_archive(
|
||||
name = "com_google_absl",
|
||||
build_file = "//third_party:com_google_absl.BUILD",
|
||||
urls = [
|
||||
"https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz",
|
||||
"https://github.com/abseil/abseil-cpp/archive/20210324.2.tar.gz",
|
||||
],
|
||||
# Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved.
|
||||
patches = [
|
||||
@ -134,8 +198,8 @@ http_archive(
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
strip_prefix = "abseil-cpp-20200225",
|
||||
sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353"
|
||||
strip_prefix = "abseil-cpp-20210324.2",
|
||||
sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f"
|
||||
)
|
||||
|
||||
http_archive(
|
||||
@ -175,12 +239,12 @@ http_archive(
|
||||
|
||||
http_archive(
|
||||
name = "libyuv",
|
||||
urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/6d603ec3f57dafddc424ef895e5d903915e94ba6.tar.gz"],
|
||||
# Adding the constraints sha256 and strip_prefix will cause failure.
|
||||
# It seems that the downloaded libyuv was different every time, so that
|
||||
# the specified sha256 and strip_prefix cannot match.
|
||||
# sha256 = "ce196c72858456baa8022fa4a0dc18b77d619265dbc0e3d58e25ad15ca402522",
|
||||
# strip_prefix = "libyuv-6d603ec3f57dafddc424ef895e5d903915e94ba6",
|
||||
urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/39240f7149cffde62e3620344d222c8ab2c21178.tar.gz"],
|
||||
# Adding the constraints sha256 and strip_prefix will cause failure as of
|
||||
# Jan 2021. It seems that the downloaded libyuv was different every time,
|
||||
# so that the specified sha256 and strip_prefix cannot match.
|
||||
# sha256 = "01c2e30eb8e83880f9ba382f6bece9c38cd5b07f9cadae46ef1d5a69e07fafaf",
|
||||
# strip_prefix = "libyuv-39240f7149cffde62e3620344d222c8ab2c21178",
|
||||
build_file = "//third_party:libyuv.BUILD",
|
||||
)
|
||||
|
||||
@ -243,18 +307,18 @@ http_archive(
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "com_google_protobuf",
|
||||
sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9",
|
||||
strip_prefix = "protobuf-3.11.4",
|
||||
urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"],
|
||||
patches = [
|
||||
"@//third_party:com_google_protobuf_fixes.diff"
|
||||
],
|
||||
patch_args = [
|
||||
"-p1",
|
||||
],
|
||||
name = "libedgetpu",
|
||||
sha256 = "a179016a5874c58db969a5edd3fecf57610604e751b5c4d6d82ad58c383ffd64",
|
||||
strip_prefix = "libedgetpu-ea1eaddbddece0c9ca1166e868f8fd03f4a3199e",
|
||||
urls = [
|
||||
"https://github.com/google-coral/libedgetpu/archive/ea1eaddbddece0c9ca1166e868f8fd03f4a3199e.tar.gz"
|
||||
],
|
||||
)
|
||||
|
||||
# Set up TensorFlow version for Coral.
|
||||
load("@libedgetpu//:workspace.bzl", "libedgetpu_dependencies")
|
||||
libedgetpu_dependencies(TENSORFLOW_COMMIT, TENSORFLOW_SHA256)
|
||||
|
||||
# AutoValue 1.6+ shades Guava, Auto Common, and JavaPoet. That's OK
|
||||
# because none of these jars become runtime dependencies.
|
||||
java_import_external(
|
||||
@ -317,12 +381,38 @@ java_import_external(
|
||||
default_visibility = ["@com_google_auto_value//:__pkg__"],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "robolectric",
|
||||
urls = ["https://github.com/robolectric/robolectric-bazel/archive/4.4.tar.gz"],
|
||||
strip_prefix = "robolectric-bazel-4.4",
|
||||
)
|
||||
load("@robolectric//bazel:robolectric.bzl", "robolectric_repositories")
|
||||
robolectric_repositories()
|
||||
|
||||
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
|
||||
|
||||
flatbuffers()
|
||||
|
||||
RULES_JVM_EXTERNAL_TAG = "3.2"
|
||||
|
||||
http_archive(
|
||||
name = "rules_jvm_external",
|
||||
strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG,
|
||||
sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af",
|
||||
url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG,
|
||||
)
|
||||
|
||||
load("@rules_jvm_external//:defs.bzl", "maven_install")
|
||||
|
||||
# Set up TF.
|
||||
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
|
||||
tf_workspace(tf_repo_name="@org_tensorflow")
|
||||
load("@org_tensorflow//tensorflow:workspace3.bzl", "workspace")
|
||||
workspace()
|
||||
load("@org_tensorflow//tensorflow:workspace2.bzl", "workspace") # buildifier: disable=load
|
||||
workspace()
|
||||
load("@org_tensorflow//tensorflow:workspace1.bzl", "workspace") # buildifier: disable=load
|
||||
workspace()
|
||||
load("@org_tensorflow//tensorflow:workspace0.bzl", "workspace") # buildifier: disable=load
|
||||
workspace()
|
||||
|
||||
load("//third_party/tensorflow:tf_configure.bzl", "tf_configure")
|
||||
tf_configure(name = "local_config_tf")
|
||||
@ -346,35 +436,38 @@ apple_support_dependencies()
|
||||
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
|
||||
bazel_version_repository(name = "bazel_version")
|
||||
|
||||
|
||||
# Set up Android.
|
||||
load("//third_party/android:android_configure.bzl", "android_configure")
|
||||
android_configure(name="local_config_android")
|
||||
load("@local_config_android//:android.bzl", "android_workspace")
|
||||
android_workspace()
|
||||
|
||||
python_configure(name = "local_config_python")
|
||||
|
||||
ATS_TAG = "androidx-test-1.3.0"
|
||||
http_archive(
|
||||
name = "android_test_support",
|
||||
strip_prefix = "android-test-%s" % ATS_TAG,
|
||||
urls = ["https://github.com/android/android-test/archive/%s.tar.gz" % ATS_TAG],
|
||||
)
|
||||
load("@android_test_support//:repo.bzl", "android_test_repositories")
|
||||
android_test_repositories()
|
||||
|
||||
# Maven dependencies.
|
||||
|
||||
RULES_JVM_EXTERNAL_TAG = "3.2"
|
||||
|
||||
http_archive(
|
||||
name = "rules_jvm_external",
|
||||
strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG,
|
||||
sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af",
|
||||
url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG,
|
||||
)
|
||||
|
||||
load("@rules_jvm_external//:defs.bzl", "maven_install")
|
||||
|
||||
maven_install(
|
||||
artifacts = [
|
||||
"androidx.annotation:annotation:aar:1.1.0",
|
||||
"androidx.annotation:annotation-experimental:1.1.0",
|
||||
"androidx.multidex:multidex:jar:2.0.1",
|
||||
"androidx.test:core:jar:1.3.0",
|
||||
"androidx.test.ext:junit:jar:1.1.2",
|
||||
"androidx.test:runner:jar:1.3.0",
|
||||
"com.google.android.odml:image:aar:1.0.0-beta1",
|
||||
"com.google.truth:truth:jar:1.1",
|
||||
"commons-io:commons-io:jar:2.8.0",
|
||||
# Mockito >= 3.4.6 cannot pass bazel desugar.
|
||||
"org.mockito:mockito-android:jar:3.0.0",
|
||||
"org.mockito:mockito-core:jar:3.0.0",
|
||||
"org.mockito:mockito-inline:jar:3.0.0",
|
||||
"org.robolectric:robolectric:jar:4.4",
|
||||
"junit:junit:jar:4.13",
|
||||
],
|
||||
repositories = [
|
||||
"https://jcenter.bintray.com",
|
||||
"https://maven.google.com",
|
||||
"https://dl.google.com/dl/android/maven2",
|
||||
"https://repo1.maven.org/maven2",
|
||||
@ -382,3 +475,23 @@ maven_install(
|
||||
fetch_sources = True,
|
||||
version_conflict_policy = "pinned",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "tf_toolchains",
|
||||
sha256 = "d72b2e52baf0592f5b94347b128ef75422fc22f63dfcf2d5fd46bc732cab052b",
|
||||
strip_prefix = "toolchains-1.3.0",
|
||||
urls = [
|
||||
"http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v1.3.0.tar.gz",
|
||||
"https://github.com/tensorflow/toolchains/archive/v1.3.0.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
load("@tf_toolchains//toolchains/embedded/arm-linux:arm_linux_toolchain_configure.bzl", "arm_linux_toolchain_configure")
|
||||
|
||||
# TFLite cross-build toolchain for embedded Linux
|
||||
arm_linux_toolchain_configure(
|
||||
name = "local_config_embedded_arm",
|
||||
build_file = "@tf_toolchains//toolchains/embedded/arm-linux:BUILD",
|
||||
aarch64_repo = "../aarch64_linux_toolchain",
|
||||
armhf_repo = "../armhf_linux_toolchain",
|
||||
)
|
||||
|
@ -8,17 +8,21 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
# LINT.IfChange
|
||||
package_group(
|
||||
name = "users",
|
||||
includes = [
|
||||
":internal",
|
||||
],
|
||||
packages = [
|
||||
"//tensorflow_lite_support/...",
|
||||
"//third_party/py/tensorflow_examples/...",
|
||||
"//third_party/tensorflow_models/...",
|
||||
],
|
||||
)
|
||||
# Remove internal path from tensorflow_lite_support:users in the copybara file.
|
||||
# LINT.ThenChange(//tensorflow_lite_support/copy.bara.sky)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//tensorflow_lite_support/...",
|
||||
],
|
||||
)
|
||||
|
||||
# Config setting for determining if we are building for Android.
|
||||
config_setting(
|
||||
|
18
third_party/tflite_support/src/tensorflow_lite_support/acceleration/README.md
vendored
Normal file
18
third_party/tflite_support/src/tensorflow_lite_support/acceleration/README.md
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
# Acceleration allowlisting
|
||||
|
||||
A complementary directory for the work of
|
||||
[accelerator allowlisting](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/experimental/acceleration)
|
||||
in TensorFlow Lite.
|
||||
|
||||
## Coral Edge TPU plugin
|
||||
|
||||
The Coral Edge TPU delegate plugin is used in the
|
||||
[acceleration library](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/port/default/tflite_wrapper.h).
|
||||
See
|
||||
[CoralSettings](https://github.com/tensorflow/tensorflow/blob/896fecee319ffeb4af2a3c0b5436f3a55ab058fa/tensorflow/lite/experimental/acceleration/configuration/configuration.proto#L323)
|
||||
about how to configure the Coral Edge TPU plugin. You can use the acceleration
|
||||
library together with
|
||||
[Task Library](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/cc/task).
|
||||
Configure your desired accelerator, including the Coral plugin through the
|
||||
options of each task, e.g.
|
||||
[image_classifier_options](https://github.com/tensorflow/tflite-support/blob/43f1267b99f1dbc27c7c5b2e1111e1ff6b9121ea/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto#L79).
|
136
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD
vendored
Normal file
136
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/BUILD
vendored
Normal file
@ -0,0 +1,136 @@
|
||||
load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow_lite_support:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_plugin",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:gpu_plugin",
|
||||
],
|
||||
alwayslink = 1, # For registration to always run.
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nnapi_plugin",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin",
|
||||
],
|
||||
alwayslink = 1, # For registration to always run.
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hexagon_plugin",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:hexagon_plugin",
|
||||
],
|
||||
alwayslink = 1, # For registration to always run.
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xnnpack_plugin",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:xnnpack_plugin",
|
||||
],
|
||||
alwayslink = 1, # For registration to always run.
|
||||
)
|
||||
|
||||
# To use the edgetpu_coral_plugin externally, add the following flag to the bazel command:
|
||||
# --define darwinn_portable=1
|
||||
cc_library(
|
||||
name = "edgetpu_coral_plugin",
|
||||
srcs = ["edgetpu_coral_plugin.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_glog//:glog",
|
||||
"@libedgetpu//tflite/public:edgetpu_c",
|
||||
"@libedgetpu//tflite/public:oss_edgetpu_direct_all", # buildcleaner: keep
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry",
|
||||
],
|
||||
alwayslink = 1, # For registration to always run.
|
||||
)
|
||||
|
||||
# To test it externally, plug in a Coral device, and run the following command:
|
||||
# bazel test tensorflow_lite_support/acceleration/configuration:edgetpu_coral_plugin_test \
|
||||
# --define darwinn_portable=1
|
||||
cc_test(
|
||||
name = "edgetpu_coral_plugin_test",
|
||||
srcs = ["edgetpu_coral_plugin_test.cc"],
|
||||
data = [
|
||||
"//tensorflow_lite_support/acceleration/configuration/testdata:test_files",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap", # Requires edge TPU device.
|
||||
],
|
||||
deps = [
|
||||
":edgetpu_coral_plugin",
|
||||
"//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# Targets for delegate plugin library Maven release.
|
||||
|
||||
# GPU delegate plugin library.
|
||||
tflite_jni_binary(
|
||||
name = "libgpu_delegate_plugin.so",
|
||||
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration/c:gpu_plugin",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_delegate_plugin_native",
|
||||
srcs = [
|
||||
":libgpu_delegate_plugin.so",
|
||||
],
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
# Android target of Acceleration@Scale GPU plugin.
|
||||
# Use this target when GPU delegate is selected in the Task Library Java API.
|
||||
android_library(
|
||||
name = "gpu_delegate_plugin_android",
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
exports = [":gpu_delegate_plugin_native"],
|
||||
)
|
||||
|
||||
# AAR target of Acceleration@Scale GPU acceleration for OSS release.
|
||||
aar_with_jni(
|
||||
name = "gpu-delegate-plugin",
|
||||
android_library = ":gpu_delegate_plugin_android",
|
||||
)
|
170
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
vendored
Normal file
170
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin.cc
vendored
Normal file
@ -0,0 +1,170 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include "absl/container/node_hash_map.h" // from @com_google_absl
|
||||
#include "absl/memory/memory.h" // from @com_google_absl
|
||||
#include "absl/strings/match.h" // from @com_google_absl
|
||||
#include "absl/strings/numbers.h" // from @com_google_absl
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
|
||||
#include "tflite/public/edgetpu_c.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
namespace {
|
||||
|
||||
// Fallback for the "Usb.MaxBulkInQueueLength" option when CoralSettings leaves
// `usb_max_bulk_in_queue_length` at 0.
constexpr int kDEFAULT_USB_MAX_BULK_IN_QUEUE_LENGTH{32};

// Device-type keywords recognized in the CoralSettings `device` string.
constexpr char kUsb[] = {"usb"};
constexpr char kPci[] = {"pci"};
|
||||
|
||||
inline std::string ConvertPerformance(
|
||||
const CoralSettings_::Performance& from_performance) {
|
||||
switch (from_performance) {
|
||||
case CoralSettings_::Performance_LOW:
|
||||
return "Low";
|
||||
case CoralSettings_::Performance_MEDIUM:
|
||||
return "Medium";
|
||||
case CoralSettings_::Performance_HIGH:
|
||||
return "High";
|
||||
default:
|
||||
return "Max";
|
||||
}
|
||||
}
|
||||
|
||||
// Spells a bool as the "True"/"False" text used for boolean Edge TPU options
// such as "Usb.AlwaysDfu".
inline std::string ConvertBool(bool from_bool) {
  if (from_bool)
    return "True";
  return "False";
}
|
||||
|
||||
bool MatchDevice(const std::string& device,
|
||||
const std::string& type,
|
||||
int* index) {
|
||||
const auto prefix(type + ":");
|
||||
if (!absl::StartsWith(device, prefix))
|
||||
return false;
|
||||
if (!absl::SimpleAtoi(device.substr(prefix.size()), index))
|
||||
return false;
|
||||
if (*index < 0)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// device_index corresponds to specific device type, e.g. "usb:0" means the
|
||||
// first USB device or "pci:0" means the first PCIe device.
|
||||
// Creates an Edge TPU delegate for a single device.
//
// `device_type`: when set, only devices of that type are counted.
// `device_index`: position of the wanted device among the counted devices;
//   defaults to 0 (the first match). Negative indices never match.
// `device_options`: forwarded to `edgetpu_create_delegate` as `edgetpu_option`
//   name/value pairs that point into the map's strings.
//
// Returns nullptr when no matching device exists.
TfLiteDelegate* CreateEdgeTpuDelegate(
    absl::optional<edgetpu_device_type> device_type,
    absl::optional<int> device_index,
    const absl::node_hash_map<std::string, std::string>& device_options) {
  // Resolve the default index up front so the device list is queried once.
  const int index = device_index.value_or(0);
  if (index < 0)
    return nullptr;

  size_t num_devices = 0;
  std::unique_ptr<edgetpu_device, decltype(&edgetpu_free_devices)> devices(
      edgetpu_list_devices(&num_devices), &edgetpu_free_devices);
  if (!devices)
    return nullptr;

  // Select the requested device before building any options.
  const edgetpu_device* selected = nullptr;
  if (device_type.has_value()) {
    // Count only devices of the requested type; stop at the index-th one.
    int type_index = 0;
    for (size_t i = 0; i < num_devices && selected == nullptr; ++i) {
      const edgetpu_device& device = devices.get()[i];
      if (device.type == *device_type && type_index++ == index)
        selected = &device;
    }
  } else if (static_cast<size_t>(index) < num_devices) {
    // `index` is non-negative here, so the cast preserves its value.
    selected = &devices.get()[index];
  }
  if (selected == nullptr)
    return nullptr;

  std::vector<edgetpu_option> options;
  options.reserve(device_options.size());
  for (const auto& device_option : device_options) {
    options.push_back(
        {device_option.first.c_str(), device_option.second.c_str()});
  }
  // `devices` is still alive here, so `selected->path` remains valid.
  return edgetpu_create_delegate(selected->type, selected->path,
                                 options.data(), options.size());
}
|
||||
|
||||
// Parses a Coral device string and creates the matching Edge TPU delegate.
//
// Accepted forms:
//   ""              -> first device of any type
//   "usb" / "pci"   -> first device of that type
//   ":N"            -> N-th device of any type
//   "usb:N"/"pci:N" -> N-th device of that type
// Any other string is logged as an error and yields nullptr.
TfLiteDelegate* CreateEdgeTpuDelegate(
    const std::string& device,
    const absl::node_hash_map<std::string, std::string>& options) {
  if (device.empty())
    return CreateEdgeTpuDelegate(absl::nullopt, absl::nullopt, options);
  if (device == kUsb)
    return CreateEdgeTpuDelegate(EDGETPU_APEX_USB, absl::nullopt, options);
  if (device == kPci)
    return CreateEdgeTpuDelegate(EDGETPU_APEX_PCI, absl::nullopt, options);

  int parsed_index;
  if (MatchDevice(device, "", &parsed_index))
    return CreateEdgeTpuDelegate(absl::nullopt, parsed_index, options);
  if (MatchDevice(device, kUsb, &parsed_index))
    return CreateEdgeTpuDelegate(EDGETPU_APEX_USB, parsed_index, options);
  if (MatchDevice(device, kPci, &parsed_index))
    return CreateEdgeTpuDelegate(EDGETPU_APEX_PCI, parsed_index, options);

  LOG(ERROR) << "Cannot match the given device string (" << device
             << ") with a Coral device.";
  return nullptr;
}
|
||||
|
||||
class EdgeTpuCoralPlugin : public DelegatePluginInterface {
|
||||
public:
|
||||
TfLiteDelegatePtr Create() override {
|
||||
return TfLiteDelegatePtr(CreateEdgeTpuDelegate(device_, options_),
|
||||
edgetpu_free_delegate);
|
||||
}
|
||||
|
||||
int GetDelegateErrno(TfLiteDelegate* from_delegate) override { return 0; }
|
||||
|
||||
static std::unique_ptr<DelegatePluginInterface> New(
|
||||
const TFLiteSettings& acceleration) {
|
||||
return absl::make_unique<EdgeTpuCoralPlugin>(acceleration);
|
||||
}
|
||||
|
||||
explicit EdgeTpuCoralPlugin(const TFLiteSettings& tflite_settings) {
|
||||
const auto* coral_settings = tflite_settings.coral_settings();
|
||||
if (!coral_settings) {
|
||||
return;
|
||||
}
|
||||
|
||||
device_ = coral_settings->device()->str();
|
||||
options_.insert(
|
||||
{"Performance", ConvertPerformance(coral_settings->performance())});
|
||||
options_.insert(
|
||||
{"Usb.AlwaysDfu", ConvertBool(coral_settings->usb_always_dfu())});
|
||||
options_.insert(
|
||||
{"Usb.MaxBulkInQueueLength",
|
||||
std::to_string(coral_settings->usb_max_bulk_in_queue_length() == 0
|
||||
? kDEFAULT_USB_MAX_BULK_IN_QUEUE_LENGTH
|
||||
: coral_settings->usb_max_bulk_in_queue_length())});
|
||||
}
|
||||
|
||||
private:
|
||||
std::string device_;
|
||||
absl::node_hash_map<std::string, std::string> options_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Registers EdgeTpuCoralPlugin::New as the factory for this plugin. The test
// retrieves it with DelegatePluginRegistry::CreateByName("EdgeTpuCoralPlugin").
TFLITE_REGISTER_DELEGATE_FACTORY_FUNCTION(EdgeTpuCoralPlugin,
|
||||
EdgeTpuCoralPlugin::New);
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
100
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
vendored
Normal file
100
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/edgetpu_coral_plugin_test.cc
vendored
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
namespace {
|
||||
|
||||
// Quantized MobileNet v1 model, Edge TPU variant (per its file name).
constexpr char kEdgeTpuModelFilePath[] =
    "tensorflow_lite_support/acceleration/configuration/"
    "testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite";
// The same quantized MobileNet v1 model without the Edge TPU suffix.
constexpr char kRegularModelFilePath[] =
    "tensorflow_lite_support/acceleration/configuration/"
    "testdata/mobilenet_v1_1.0_224_quant.tflite";
// Input image fed to both models.
constexpr char kImagePath[] =
    "tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg";
|
||||
|
||||
using ::tflite::task::vision::DecodeImageFromFile;
|
||||
using ::tflite::task::vision::ImageData;
|
||||
using ::tflite::task::vision::ImageDataFree;
|
||||
|
||||
using EdgeTpuCoralPluginTest = testing::TestWithParam<std::string>;
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(CoralPluginTests,
|
||||
EdgeTpuCoralPluginTest,
|
||||
testing::Values(kRegularModelFilePath,
|
||||
kEdgeTpuModelFilePath));
|
||||
|
||||
TEST_P(EdgeTpuCoralPluginTest, CreateEdgeTpuCoralPlugin) {
  // Create the Coral delegate from the Coral plugin.
  flatbuffers::FlatBufferBuilder flatbuffer_builder;
  auto settings = flatbuffers::GetTemporaryPointer(
      flatbuffer_builder,
      CreateTFLiteSettings(flatbuffer_builder, tflite::Delegate_EDGETPU_CORAL));
  auto plugin = ::tflite::delegates::DelegatePluginRegistry::CreateByName(
      "EdgeTpuCoralPlugin", *settings);
  // Fail the test instead of crashing if the plugin is not registered.
  ASSERT_NE(plugin, nullptr);
  auto coral_delegate = plugin->Create();

  // Load the tflite model file under test.
  std::unique_ptr<::tflite::FlatBufferModel> tflite_model =
      ::tflite::FlatBufferModel::BuildFromFile(GetParam().c_str());
  ASSERT_NE(tflite_model, nullptr);

  // Create the tflite interpreter and apply the Coral delegate.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<::tflite::Interpreter> interpreter;
  ASSERT_EQ(::tflite::InterpreterBuilder(*tflite_model, resolver)(&interpreter),
            kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
  // NOTE(review): the returned status is not checked; confirm whether it should
  // be asserted for both parameterized models.
  interpreter->ModifyGraphWithDelegate(coral_delegate.get());

  // Verifies that the interpreter runs correctly.
  // To open source the code under tensorflow/lite, the following code needs to
  // be stripped of the Task library dependency, meaning forking or rewriting
  // `DecodeImageFromFile` and `ImageData`.
  // `ASSERT_OK_AND_ASSIGN` is not available externally.
  auto rgb_image_or = DecodeImageFromFile(kImagePath);
  ASSERT_TRUE(rgb_image_or.ok());

  ImageData rgb_image = rgb_image_or.value();
  // Frees the decoded pixels on every exit path, including the early returns
  // taken by failing ASSERT_* macros below.
  std::unique_ptr<ImageData, decltype(&ImageDataFree)> rgb_image_cleanup(
      &rgb_image, &ImageDataFree);
  const uint8_t* input_data = rgb_image.pixel_data;
  const size_t input_data_byte_size =
      rgb_image.width * rgb_image.height * rgb_image.channels * sizeof(uint8_t);

  ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
  // The memcpy below writes `input_data_byte_size` bytes into the input
  // tensor, so the decoded image must exactly match the tensor's size.
  ASSERT_EQ(input_data_byte_size, interpreter->input_tensor(0)->bytes);
  uint8_t* input_tensor = interpreter->typed_input_tensor<uint8_t>(0);
  memcpy(input_tensor, input_data, input_data_byte_size);
  ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
  uint8_t* output_tensor = interpreter->typed_output_tensor<uint8_t>(0);
  // `cheeseburger` is the 935th item in the label file of
  // "mobilenet_v1_1.0_224_quant_edgetpu.tflite". See labels.txt.
  EXPECT_EQ(output_tensor[934], 255);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace delegates
|
||||
} // namespace tflite
|
12
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/BUILD
vendored
Normal file
12
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/BUILD
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
package(
    default_visibility = ["//tensorflow_lite_support:internal"],
    licenses = ["notice"],  # Apache 2.0
)

# Test models and images in this directory (e.g. those read by
# edgetpu_coral_plugin_test.cc).
filegroup(
    name = "test_files",
    srcs = glob(["*.tflite", "*.jpg"]),
)
|
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg
vendored
Normal file
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/burger.jpg
vendored
Normal file
Binary file not shown.
After ![]() (image error) Size: 75 KiB |
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant.tflite
vendored
Normal file
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant.tflite
vendored
Normal file
Binary file not shown.
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite
vendored
Normal file
BIN
third_party/tflite_support/src/tensorflow_lite_support/acceleration/configuration/testdata/mobilenet_v1_1.0_224_quant_edgetpu.tflite
vendored
Normal file
Binary file not shown.
23
third_party/tflite_support/src/tensorflow_lite_support/c/BUILD
vendored
Normal file
23
third_party/tflite_support/src/tensorflow_lite_support/c/BUILD
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
package(
    default_visibility = ["//tensorflow_lite_support:internal"],
    licenses = ["notice"],  # Apache 2.0
)

# C error struct and error codes shared by the Task Library C APIs.
cc_library(
    name = "common",
    srcs = ["common.cc"],
    hdrs = ["common.h"],
)

# Helpers converting absl::Status into TfLiteSupportError.
cc_library(
    name = "common_utils",
    srcs = ["common_utils.cc"],
    hdrs = ["common_utils.h"],
    deps = [
        ":common",
        "//tensorflow_lite_support/cc:common",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
    ],
)
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -12,16 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#inmport "base/strings/sys_string_conversions.h"
|
||||
#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
|
||||
|
||||
// Converts an NSString into a std::string by delegating to SysNSStringToUTF8.
std::string MakeString(NSString* str) {
|
||||
return SysNSStringToUTF8(str);
|
||||
}
|
||||
#include "tensorflow_lite_support/c/common.h"
|
||||
|
||||
NSString* MakeNSString(const std::string& str) {
|
||||
return [[NSString alloc]
|
||||
initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
|
||||
length:str.length()
|
||||
encoding:NSUTF8StringEncoding];
|
||||
#include <cstdlib>
|
||||
|
||||
// Releases a TfLiteSupportError together with its `message`.
// Passing NULL is a no-op, mirroring free(NULL), so callers can
// unconditionally delete an out-parameter that may never have been set.
void TfLiteSupportErrorDelete(TfLiteSupportError* error) {
  if (error == nullptr)
    return;
  // `strdup` obtains memory using `malloc` and the memory needs to be
  // released using `free`.
  free(error->message);
  delete error;
}
|
202
third_party/tflite_support/src/tensorflow_lite_support/c/common.h
vendored
Normal file
202
third_party/tflite_support/src/tensorflow_lite_support/c/common.h
vendored
Normal file
@ -0,0 +1,202 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_COMMON_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_COMMON_H_
|
||||
|
||||
// Defines C struct and error codes for describing any error returned from the C
|
||||
// Task Library.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Error codes for TensorFlow Lite Task Library C APIs.
//
// Holds a one-to-one mapping with `TfLiteSupportStatus` codes, starting from
// kError = 1. `kOk` is omitted because a `TfLiteSupportErrorCode` is only used
// in the event of an error; on success, the Task Library C APIs return the
// appropriate value and a null error. The one-to-one mapping makes it possible
// to convert between `TfLiteSupportStatus` and `TfLiteSupportErrorCode` without
// long switch statements.
//
// Also holds error codes mapping to absl::Status::code(), starting from
// kNotFoundError = 900, for cases where the absl::Status payload can't be
// mapped to a `TfLiteSupportStatus` code.
//
// kErrorCodeFirst and kErrorCodeLast are provided for safety checks during
// conversion between `TfLiteSupportStatus` and `TfLiteSupportErrorCode`. When
// modifying error codes, ensure that kErrorCodeFirst and kErrorCodeLast are set
// to the least and greatest enum values, respectively, among the codes mapping
// to `TfLiteSupportStatus`.
//
// Every enumerator is given an explicit value so the numeric mapping can be
// read directly from this declaration.
enum TfLiteSupportErrorCode {
  // Unspecified error.
  kError = 1,
  // Invalid argument specified.
  kInvalidArgumentError = 2,
  // Invalid FlatBuffer file or buffer specified.
  kInvalidFlatBufferError = 3,
  // Model contains a builtin op that isn't supported by the OpResolver or
  // delegates.
  kUnsupportedBuiltinOpError = 4,
  // Model contains a custom op that isn't supported by the OpResolver or
  // delegates.
  kUnsupportedCustomOpError = 5,

  // File I/O error codes.

  // No such file.
  kFileNotFoundError = 100,
  // Permission issue.
  kFilePermissionDeniedError = 101,
  // I/O error when reading file.
  kFileReadError = 102,
  // I/O error when mmap-ing file.
  kFileMmapError = 103,

  // TensorFlow Lite metadata error codes.

  // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
  kMetadataInvalidSchemaVersionError = 200,
  // No such associated file within metadata, or file has not been packed.
  kMetadataAssociatedFileNotFoundError = 201,
  // ZIP I/O error when unpacking an associated file.
  kMetadataAssociatedFileZipError = 202,
  // Inconsistency error between the metadata and actual TF Lite model.
  // E.g.: number of labels and output tensor values differ.
  kMetadataInconsistencyError = 203,
  // Invalid process units specified.
  // E.g.: multiple ProcessUnits with the same type for a given tensor.
  kMetadataInvalidProcessUnitsError = 204,
  // Inconsistency error with the number of labels.
  // E.g.: label files for different locales have a different number of labels.
  kMetadataNumLabelsMismatchError = 205,
  // Score calibration parameters parsing error.
  // E.g.: too many parameters provided in the corresponding associated file.
  kMetadataMalformedScoreCalibrationError = 206,
  // Unexpected number of subgraphs for the current task.
  // E.g.: image classification expects a single subgraph.
  kMetadataInvalidNumSubgraphsError = 207,
  // A given tensor requires NormalizationOptions but none were found.
  // E.g.: float input tensor requires normalization to preprocess input images.
  kMetadataMissingNormalizationOptionsError = 208,
  // Invalid ContentProperties specified.
  // E.g. expected ImageProperties, got BoundingBoxProperties.
  kMetadataInvalidContentPropertiesError = 209,
  // Metadata is mandatory but was not found.
  // E.g. current task requires TFLite Model Metadata but none was found.
  kMetadataNotFoundError = 210,
  // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but
  // none was found or it was empty.
  // E.g. current task requires labels but none were found.
  kMetadataMissingLabelsError = 211,
  // The ProcessingUnit for tokenizer is not correctly configured.
  // E.g. BertTokenizer doesn't have a valid vocab file associated.
  kMetadataInvalidTokenizerError = 212,

  // Input tensor(s) error codes.

  // Unexpected number of input tensors for the current task.
  // E.g. current task expects a single input tensor.
  kInvalidNumInputTensorsError = 300,
  // Unexpected input tensor dimensions for the current task.
  // E.g.: only 4D input tensors supported.
  kInvalidInputTensorDimensionsError = 301,
  // Unexpected input tensor type for the current task.
  // E.g.: current task expects a uint8 pixel image as input.
  kInvalidInputTensorTypeError = 302,
  // Unexpected input tensor bytes size.
  // E.g.: size in bytes does not correspond to the expected number of pixels.
  kInvalidInputTensorSizeError = 303,
  // No correct input tensor found for the model.
  // E.g.: input tensor name is not part of the text model's input tensors.
  kInputTensorNotFoundError = 304,

  // Output tensor(s) error codes.

  // Unexpected output tensor dimensions for the current task.
  // E.g.: only a batch size of 1 is supported.
  kInvalidOutputTensorDimensionsError = 400,
  // Unexpected output tensor type for the current task.
  // E.g.: multi-head model with different output tensor types.
  kInvalidOutputTensorTypeError = 401,
  // No correct output tensor found for the model.
  // E.g.: output tensor name is not part of the text model's output tensors.
  kOutputTensorNotFoundError = 402,
  // Unexpected number of output tensors for the current task.
  // E.g.: current task expects a single output tensor.
  kInvalidNumOutputTensorsError = 403,

  // Image processing error codes.

  // Unspecified image processing failures.
  kImageProcessingError = 500,
  // Unexpected input or output buffer metadata.
  // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees.
  kImageProcessingInvalidArgumentError = 501,
  // Image processing operation failures.
  // E.g. libyuv rotation failed for an unknown reason.
  kImageProcessingBackendError = 502,

  // Convenience error codes for condition checks during type casting.
  //
  // Codes mapping to absl status codes must not be considered for these
  // ranges. They must be used exclusively for checking whether error codes
  // fall in valid ranges when converting between TfLiteSupportStatus and
  // TfLiteSupportErrorCode.

  // Ensure it holds the least enum value amongst error codes mapping to
  // TfLiteSupportStatus.
  kErrorCodeFirst = kError,
  // Ensure it holds the greatest enum value amongst error codes mapping to
  // TfLiteSupportStatus.
  kErrorCodeLast = kImageProcessingBackendError,

  // Absl Status Codes Mapping
  //
  // Codes starting from 900 are used to map absl::Status objects created by
  // TfLite and passed through as-is by the TfLite Support C++ layer. Such absl
  // status objects don't carry a TfLiteSupportStatus payload that could be
  // mapped to the other error codes in this enum. Map absl::Status::code() to
  // the following error codes in such cases.
  // For more info on the respective absl status codes, please see:
  // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L91

  // kNotFound indicates some requested entity (such as a file or directory)
  // was not found.
  kNotFoundError = 900,
  // kInternal indicates an internal error has occurred
  // and some invariants expected by the underlying system have not been
  // satisfied. This error code is reserved for serious errors.
  kInternalError = 901,
};
|
||||
|
||||
// A `TfLiteSupportError` encapsulates an error code and a descriptive message
|
||||
// to return in the event of an error being encountered in any TensorFlow Lite
|
||||
// Task Library C API.
|
||||
typedef struct TfLiteSupportError {
|
||||
// Holds the error code.
|
||||
enum TfLiteSupportErrorCode code;
|
||||
// Detailed description of the error.
|
||||
char* message;
|
||||
} TfLiteSupportError;
|
||||
|
||||
void TfLiteSupportErrorDelete(TfLiteSupportError* error);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_COMMON_H_
|
111
third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
vendored
Normal file
111
third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.cc
vendored
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/common_utils.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/cord.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
|
||||
const char* message,
|
||||
TfLiteSupportError** error) {
|
||||
if (error == nullptr)
|
||||
return;
|
||||
|
||||
*error = new TfLiteSupportError;
|
||||
(*error)->code = code;
|
||||
(*error)->message = strdup(message);
|
||||
}
|
||||
|
||||
void CreateTfLiteSupportErrorWithStatus(const absl::Status& status,
                                        TfLiteSupportError** error) {
  if (status.ok() || error == nullptr)
    return;

  // The payload of an absl::Status created by the tflite task library stores an
  // appropriate value of the enum TfLiteSupportStatus. The integer value
  // corresponding to the TfLiteSupportStatus enum stored in the payload is
  // extracted here to later map to the appropriate error code to be returned.
  // When no usable value is stored (the payload is absent, is not numeric, or
  // does not fit in an int), the error code is set to 1 (kError of
  // TfLiteSupportErrorCode, used in the C library to signify errors not falling
  // into other categories). Since the payload is of type absl::Cord, which can
  // be cast into an absl::optional<std::string>, std::stoi is used to convert
  // it into an integer code if possible.
  const int generic_error_code = static_cast<int>(kError);
  int error_code = generic_error_code;
  try {
    // Convert the payload to an integer if present. Otherwise convert the
    // string form of the generic error code kError.
    error_code = std::stoi(static_cast<absl::optional<std::string>>(
                               status.GetPayload(kTfLiteSupportPayload))
                               .value_or(std::to_string(generic_error_code)));
  } catch (const std::invalid_argument&) {
    // The payload string is not numeric.
    error_code = generic_error_code;
  } catch (const std::out_of_range&) {
    // The payload is numeric but overflows int; std::stoi reports this with a
    // different exception type that must be handled as well, otherwise it
    // would propagate out of this function.
    error_code = generic_error_code;
  }

  // If error_code is outside the range of possible enum values or is kError,
  // try to map absl::Status::code() to an appropriate TfLiteSupportErrorCode,
  // falling back to kError. Mapping absl::Status::code() yields a more
  // specific error code than kError when the payload can't be mapped to
  // TfLiteSupportStatus. This happens when an absl::Status returned by TfLite
  // is in turn returned without modification by TfLite Support methods.
  if (error_code > static_cast<int>(kErrorCodeLast) ||
      error_code <= static_cast<int>(kErrorCodeFirst)) {
    switch (status.code()) {
      case absl::StatusCode::kInternal:
        error_code = kInternalError;
        break;
      case absl::StatusCode::kInvalidArgument:
        error_code = kInvalidArgumentError;
        break;
      case absl::StatusCode::kNotFound:
        error_code = kNotFoundError;
        break;
      default:
        error_code = kError;
        break;
    }
  }

  // Creates the TfLiteSupportError with the appropriate TfLiteSupportErrorCode
  // and message. TfLiteSupportErrorCode has a one-to-one mapping with
  // TfLiteSupportStatus starting from the value 1 (kError). It is therefore
  // correctly initialized when directly cast from the integer code derived from
  // the TfLiteSupportStatus stored in the payload. TfLiteSupportErrorCode omits
  // kOk = 0 of TfLiteSupportStatus.
  //
  // The error message is a string including the absl status code and message
  // (if non-empty). See
  // https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h#L514
  // for an explanation. absl::Status::message() could also be used, but it is
  // not guaranteed to be non-empty.
  CreateTfLiteSupportError(
      static_cast<TfLiteSupportErrorCode>(error_code),
      status.ToString(absl::StatusToStringMode::kWithNoExtraData).c_str(),
      error);
}
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
57
third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
vendored
Normal file
57
third_party/tflite_support/src/tensorflow_lite_support/c/common_utils.h
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_
|
||||
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/c/common.h"
|
||||
|
||||
// Utils for Conversion of absl::Status to TfLiteError
|
||||
// -----------------------------------------------------------------
|
||||
// Meant to be used with task C apis.
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
// Creates a TfLiteSupportError with a TfLiteSupportErrorCode and message.
|
||||
void CreateTfLiteSupportError(enum TfLiteSupportErrorCode code,
|
||||
const char* message,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Creates a TfLiteSupportError from absl::Status and passes it back as a
|
||||
// parameter which is a pointer to the error pointer.
|
||||
//
|
||||
// Example Usage With Image Classifier
|
||||
//
|
||||
// APIs: TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
|
||||
// const TfLiteImageClassifierOptions* options,
|
||||
// TfLiteSupportError **error) {
|
||||
// // Necessary checks
|
||||
// tflite::support::StatusOr<std::unique_ptr<ImageClassifier>> classifier_status
|
||||
// = // Call to create Cpp Image Classifier.
|
||||
// if (classifier_status.ok()) {
|
||||
// Code to return classifier
|
||||
// } else {
|
||||
// ::tflite::support::CreateTfLiteSupportErrorWithStatus(classifier_status.status(),
|
||||
// error);
|
||||
// return nullptr;
|
||||
// }
|
||||
//}
|
||||
void CreateTfLiteSupportErrorWithStatus(const absl::Status& status,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_COMMON_UTILS_H_
|
9
third_party/tflite_support/src/tensorflow_lite_support/c/task/core/BUILD
vendored
Normal file
9
third_party/tflite_support/src/tensorflow_lite_support/c/task/core/BUILD
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow_lite_support:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base_options",
|
||||
hdrs = ["base_options.h"],
|
||||
)
|
73
third_party/tflite_support/src/tensorflow_lite_support/c/task/core/base_options.h
vendored
Normal file
73
third_party/tflite_support/src/tensorflow_lite_support/c/task/core/base_options.h
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Defines C Structs for Base Options Shared by all tasks.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Holds cpu settings.
|
||||
typedef struct TfLiteCpuSettings {
|
||||
// Specifies the number of threads to be used for TFLite
|
||||
// ops that support multi-threading when running inference with CPU.
|
||||
// num_threads should be greater than 0 or equal to -1. Setting num_threads to
|
||||
// -1 has the effect to let TFLite runtime set the value.
|
||||
int num_threads;
|
||||
} TfLiteCpuSettings;
|
||||
|
||||
// Holds settings for one possible acceleration configuration.
|
||||
typedef struct TfLiteComputeSettings {
|
||||
// Holds cpu settings
|
||||
TfLiteCpuSettings cpu_settings;
|
||||
} TfLiteComputeSettings;
|
||||
|
||||
// Represents external files used by the Task APIs (e.g. TF Lite Model File).
|
||||
// For now you can only specify the path of the file using file_path:
|
||||
// In future other sources may be supported.
|
||||
typedef struct TfLiteExternalFile {
|
||||
// The path to the file to open.
|
||||
const char* file_path;
|
||||
// Additional option for byte data when it's supported.
|
||||
} TfLiteExternalFile;
|
||||
|
||||
// Holds the base options that is used for creation of any type of task. It has
|
||||
// fields withh important information acceleration configuration, tflite model
|
||||
// source etc.
|
||||
// This struct must be zero initialized before setting any options as this
|
||||
// will result in seg faults.
|
||||
typedef struct TfLiteBaseOptions {
|
||||
// The external model file, as a single standalone TFLite file. It could be
|
||||
// packed with TFLite Model Metadata[1] and associated files if exist. Fail to
|
||||
// provide the necessary metadata and associated files might result in errors.
|
||||
// Check the documentation for each task about the specific requirement.
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
TfLiteExternalFile model_file;
|
||||
|
||||
// Holds settings for one possible acceleration configuration
|
||||
// including.cpu/gpu settings. Please see documentation of
|
||||
// TfLiteComputeSettings and its members for more details.
|
||||
TfLiteComputeSettings compute_settings;
|
||||
} TfLiteBaseOptions;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_CORE_BASE_OPTIONS_H_
|
42
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/BUILD
vendored
Normal file
42
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/BUILD
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow_lite_support:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "category",
|
||||
hdrs = ["category.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_result",
|
||||
srcs = [
|
||||
"classification_result.cc",
|
||||
],
|
||||
hdrs = ["classification_result.h"],
|
||||
deps = [
|
||||
":category",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bounding_box",
|
||||
hdrs = ["bounding_box.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_options",
|
||||
hdrs = ["classification_options.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detection_result",
|
||||
srcs = [
|
||||
"detection_result.cc",
|
||||
],
|
||||
hdrs = ["detection_result.h"],
|
||||
deps = [
|
||||
":bounding_box",
|
||||
":category",
|
||||
],
|
||||
)
|
45
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/bounding_box.h
vendored
Normal file
45
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/bounding_box.h
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Defines C Struct for Bounding Box Shared by Vision Tasks.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Holds the region of interest used for image classification.
|
||||
typedef struct TfLiteBoundingBox {
|
||||
// The X coordinate of the top-left corner, in pixels.
|
||||
int origin_x;
|
||||
|
||||
// The Y coordinate of the top-left corner, in pixels.
|
||||
int origin_y;
|
||||
|
||||
// The width of the bounding box, in pixels.
|
||||
int width;
|
||||
|
||||
// The height of the bounding box, in pixels.
|
||||
int height;
|
||||
} TfLiteBoundingBox;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_BOUNDING_BOX_H_
|
49
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/category.h
vendored
Normal file
49
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/category.h
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CATEGORY_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CATEGORY_H_
|
||||
|
||||
// Defines C structure for a Category which encapsulates a single predicted
|
||||
// class.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// A single predicted class.
|
||||
typedef struct TfLiteCategory {
|
||||
// The index of the class in the corresponding label map, usually packed in
|
||||
// the TFLite Model Metadata [1].
|
||||
//
|
||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||
int index;
|
||||
|
||||
// The score for this class e.g. (but not necessarily) a probability in [0,1].
|
||||
float score;
|
||||
|
||||
// A human readable name of the class filled from the label map.
|
||||
char* display_name;
|
||||
// An ID for the class, not necessarily human-readable (e.g. a Google
|
||||
// Knowledge Graph ID [1]), filled from the label map.
|
||||
//
|
||||
// [1]: https://developers.google.com/knowledge-graph
|
||||
char* label;
|
||||
} TfLiteCategory;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_CATEGORY_H_
|
66
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_options.h
vendored
Normal file
66
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_options.h
vendored
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Defines C Struct for Classification Options Shared by All Classification
|
||||
// Tasks.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Holds pointer to array of C strings and length for looping through the array.
|
||||
typedef struct TfLiteStringArrayOption {
|
||||
// Length of list. length can be used to loop through list.
|
||||
int length;
|
||||
|
||||
// Array of C strings.
|
||||
char** list;
|
||||
} TfLiteStringArrayOption;
|
||||
|
||||
// Holds settings for any single classification task.
|
||||
typedef struct TfLiteClassificationOptions {
|
||||
// Optional denylist of class labels. If non NULL, classifications whose
|
||||
// class label is in this set will be filtered out. Duplicate or unknown
|
||||
// class labels are ignored. Mutually exclusive with label_allowlist.
|
||||
TfLiteStringArrayOption label_denylist;
|
||||
|
||||
// Optional allowlist of class labels. If non-empty, classifications whose
|
||||
// class label is not in this set will be filtered out. Duplicate or unknown
|
||||
// class labels are ignored. Mutually exclusive with label_denylist.
|
||||
TfLiteStringArrayOption label_allowlist;
|
||||
|
||||
// The locale to use for display names specified through the TFLite Model
|
||||
// Metadata, if any. Defaults to English.
|
||||
char* display_names_local;
|
||||
|
||||
// The maximum number of top-scored classification results to return. If < 0,
|
||||
// all available results will be returned. If 0, an invalid argument error is
|
||||
// returned. Defaults to -1.
|
||||
int max_results;
|
||||
|
||||
// Score threshold, overrides the ones provided in the model metadata
|
||||
// (if any). Results below this value are rejected.
|
||||
float score_threshold;
|
||||
} TfLiteClassificationOptions;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_OPTIONS_H_
|
46
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
vendored
Normal file
46
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.cc
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/task/processor/classification_result.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
void TfLiteClassificationResultDelete(
|
||||
TfLiteClassificationResult* classification_result) {
|
||||
for (int head = 0; head < classification_result->size; ++head) {
|
||||
TfLiteClassifications classifications =
|
||||
classification_result->classifications[head];
|
||||
for (int rank = 0; rank < classifications.size; ++rank) {
|
||||
// `strdup` obtains memory using `malloc` and the memory needs to be
|
||||
// released using `free`.
|
||||
free(classifications.categories[rank].display_name);
|
||||
free(classifications.categories[rank].label);
|
||||
}
|
||||
|
||||
delete[] classifications.categories;
|
||||
}
|
||||
|
||||
delete[] classification_result->classifications;
|
||||
delete classification_result;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
63
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
vendored
Normal file
63
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/classification_result.h
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_RESULT_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_CLASSIFICATION_RESULT_H_
|
||||
|
||||
#include "tensorflow_lite_support/c/task/processor/category.h"
|
||||
|
||||
// Defines C structure for Classification Results and associated helper methods.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// List of predicted classes (aka labels) for a given image classifier head.
|
||||
typedef struct TfLiteClassifications {
|
||||
// The index of the image classifier head these classes refer to. This is
|
||||
// useful for multi-head models.
|
||||
int head_index;
|
||||
|
||||
// Number of predicted classes which can be used to traverse the array of
|
||||
// predicted classes.
|
||||
int size;
|
||||
|
||||
// The array of predicted classes, usually sorted by descending scores (e.g.
|
||||
// from high to low probability). Since this array is dynamically allocated,
|
||||
// use size to traverse through the array.
|
||||
TfLiteCategory* categories;
|
||||
} TfLiteClassifications;
|
||||
|
||||
// Holds Image Classification results.
|
||||
// Contains one set of results per image classifier head.
|
||||
typedef struct TfLiteClassificationResult {
|
||||
// Number of predicted classes which can be used to traverse the array of
|
||||
// predicted classes.
|
||||
int size;
|
||||
|
||||
// Array of image classifier results per image classifier head. This array can
|
||||
// have any number of results. size holds the size of this array. size should
|
||||
// be used to traverse this array.
|
||||
TfLiteClassifications* classifications;
|
||||
} TfLiteClassificationResult;
|
||||
|
||||
// Frees up the ClassificationResult Structure.
|
||||
void TfLiteClassificationResultDelete(
|
||||
TfLiteClassificationResult* classification_result);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_CLASSIFICATION_RESULT_H_
|
65
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/detection_result.h
vendored
Normal file
65
third_party/tflite_support/src/tensorflow_lite_support/c/task/processor/detection_result.h
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_DETECTION_RESULT_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_PROCESSOR_DETECTION_RESULT_H_
|
||||
|
||||
#include "tensorflow_lite_support/c/task/processor/bounding_box.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/category.h"
|
||||
|
||||
// Defines C structure for Object Detection Results and associated helper
|
||||
// methods.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Bounding box and list of predicted classes (aka labels) for a detected
|
||||
// object.
|
||||
typedef struct TfLiteDetection {
|
||||
// The bounding box of the detected object.
|
||||
TfLiteBoundingBox bounding_box;
|
||||
|
||||
// The array of predicted classes for the object detection represented by an
|
||||
// instance of TfLiteDetection, usually sorted by descending scores (e.g. from
|
||||
// high to low probability). Since this array is dynamically allocated, use
|
||||
// size to traverse through the array.
|
||||
TfLiteCategory* categories;
|
||||
|
||||
// Number of detectd objects be used to traverse the array of the detected
|
||||
// objects.
|
||||
int size;
|
||||
} TfLiteDetection;
|
||||
|
||||
// Holds Object Detection results.
|
||||
// Contains one set of results per detected object.
|
||||
typedef struct TfLiteDetectionResult {
|
||||
// Number of detectd objects be used to traverse the array of the detected
|
||||
// objects.
|
||||
int size;
|
||||
|
||||
// Array of results per detected object. This array can
|
||||
// have any number of results. size holds the size of this array. size should
|
||||
// be used to traverse this array.
|
||||
TfLiteDetection* detections;
|
||||
} TfLiteDetectionResult;
|
||||
|
||||
// Frees up the DetectionResult Structure.
|
||||
void TfLiteDetectionResultDelete(TfLiteDetectionResult* detection_result);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_DETECTION_H_
|
79
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/BUILD
vendored
Normal file
79
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/BUILD
vendored
Normal file
@ -0,0 +1,79 @@
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
|
||||
"cc_library_with_tflite",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow_lite_support:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"bert_nl_classifier.h",
|
||||
"nl_classifier.h",
|
||||
"nl_classifier_common.h",
|
||||
"bert_question_answerer.h",
|
||||
])
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "nl_classifier",
|
||||
srcs = [
|
||||
"nl_classifier.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"nl_classifier.h",
|
||||
"nl_classifier_common.h",
|
||||
],
|
||||
tflite_deps = [
|
||||
"//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier",
|
||||
],
|
||||
deps = [
|
||||
":nl_classifier_common",
|
||||
"//tensorflow_lite_support/cc/task/core:category",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "bert_nl_classifier",
|
||||
srcs = [
|
||||
"bert_nl_classifier.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"bert_nl_classifier.h",
|
||||
"nl_classifier_common.h",
|
||||
],
|
||||
tflite_deps = [
|
||||
"//tensorflow_lite_support/cc/task/text:bert_nl_classifier",
|
||||
],
|
||||
deps = [
|
||||
":nl_classifier_common",
|
||||
"//tensorflow_lite_support/cc/task/core:category",
|
||||
"//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nl_classifier_common",
|
||||
srcs = [
|
||||
"nl_classifier_common.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"nl_classifier_common.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "bert_question_answerer",
|
||||
srcs = [
|
||||
"bert_question_answerer.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"bert_question_answerer.h",
|
||||
],
|
||||
tflite_deps = [
|
||||
"//tensorflow_lite_support/cc/task/text:bert_question_answerer",
|
||||
"//tensorflow_lite_support/cc/task/text:question_answerer",
|
||||
],
|
||||
)
|
93
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
vendored
Normal file
93
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.cc
vendored
Normal file
@ -0,0 +1,93 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/task/text/bert_nl_classifier.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/task/core/category.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h"
|
||||
|
||||
namespace {
|
||||
using CategoryCpp = ::tflite::task::core::Category;
|
||||
using BertNLClassifierCpp = ::tflite::task::text::BertNLClassifier;
|
||||
using BertNLClassifierOptionsCpp =
|
||||
::tflite::task::text::BertNLClassifierOptions;
|
||||
|
||||
const TfLiteBertNLClassifierOptions kBertNLClassifierOptionsDefault = {128};
|
||||
} // namespace
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
struct TfLiteBertNLClassifier {
|
||||
std::unique_ptr<BertNLClassifierCpp> impl;
|
||||
};
|
||||
|
||||
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
|
||||
const char* model_path,
|
||||
const TfLiteBertNLClassifierOptions* options) {
|
||||
BertNLClassifierOptionsCpp cc_options;
|
||||
|
||||
cc_options.mutable_base_options()->mutable_model_file()->set_file_name(
|
||||
model_path);
|
||||
auto classifier_status = BertNLClassifierCpp::CreateFromOptions(cc_options);
|
||||
|
||||
if (classifier_status.ok()) {
|
||||
return new TfLiteBertNLClassifier{
|
||||
.impl = std::unique_ptr<BertNLClassifierCpp>(
|
||||
dynamic_cast<BertNLClassifierCpp*>(
|
||||
classifier_status.value().release()))};
|
||||
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path) {
|
||||
return TfLiteBertNLClassifierCreateFromOptions(
|
||||
model_path, &kBertNLClassifierOptionsDefault);
|
||||
}
|
||||
|
||||
Categories* TfLiteBertNLClassifierClassify(
|
||||
const TfLiteBertNLClassifier* classifier,
|
||||
const char* text) {
|
||||
std::vector<CategoryCpp> results =
|
||||
|
||||
classifier->impl->Classify(absl::string_view(text).data());
|
||||
size_t size = results.size();
|
||||
auto* categories = new Category[size];
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
categories[i].text = strdup(results[i].class_name.c_str());
|
||||
categories[i].score = results[i].score;
|
||||
}
|
||||
|
||||
auto* c_categories = new Categories;
|
||||
c_categories->size = size;
|
||||
c_categories->categories = categories;
|
||||
return c_categories;
|
||||
}
|
||||
|
||||
void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier) {
|
||||
delete classifier;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
70
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
vendored
Normal file
70
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_nl_classifier.h
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_
|
||||
|
||||
#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h"
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for BertNLClassifier.
|
||||
//
|
||||
// Usage:
|
||||
// // Create the model and interpreter options.
|
||||
// TfLiteBertNLClassifier* classifier =
|
||||
// TfLiteBertNLClassifierCreate("/path/to/model.tflite");
|
||||
//
|
||||
// // Classification.
|
||||
// Categories* categories = TfLiteBertNLClassifierClassify(classifier,
|
||||
// question);
|
||||
//
|
||||
// // Dispose of the API object.
|
||||
// TfLiteBertNLClassifierDelete(classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
typedef struct TfLiteBertNLClassifier TfLiteBertNLClassifier;
|
||||
|
||||
typedef struct TfLiteBertNLClassifierOptions {
|
||||
// Max number of tokens to pass to the model.
|
||||
//
|
||||
// Deprecated: max_seq_len is now read from the model (i.e. input tensor size)
|
||||
// automatically.
|
||||
int max_seq_len;
|
||||
} TfLiteBertNLClassifierOptions;
|
||||
|
||||
// Creates TfLiteBertNLClassifier from model path and options, returns nullptr
|
||||
// if the file doesn't exist or is not a well formatted TFLite model path.
|
||||
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreateFromOptions(
|
||||
const char* model_path,
|
||||
const TfLiteBertNLClassifierOptions* options);
|
||||
|
||||
// Creates TfLiteBertNLClassifier from model path and default options, returns
|
||||
// nullptr if the file doesn't exist or is not a well formatted TFLite model
|
||||
// path.
|
||||
TfLiteBertNLClassifier* TfLiteBertNLClassifierCreate(const char* model_path);
|
||||
|
||||
// Invokes the encapsulated TFLite model and classifies the input text.
|
||||
Categories* TfLiteBertNLClassifierClassify(
|
||||
const TfLiteBertNLClassifier* classifier,
|
||||
const char* text);
|
||||
|
||||
void TfLiteBertNLClassifierDelete(TfLiteBertNLClassifier* classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_NL_CLASSIFIER_H_
|
@ -13,45 +13,48 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h"
|
||||
#include "tensorflow_lite_support/c/task/text/bert_question_answerer.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/question_answerer.h"
|
||||
|
||||
using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer;
|
||||
using QaAnswerCPP = ::tflite::task::text::qa::QaAnswer;
|
||||
namespace {
|
||||
using BertQuestionAnswererCpp = ::tflite::task::text::BertQuestionAnswerer;
|
||||
using QaAnswerCpp = ::tflite::task::text::QaAnswer;
|
||||
} // namespace
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
struct BertQuestionAnswerer {
|
||||
std::unique_ptr<BertQuestionAnswererCPP> impl;
|
||||
struct TfLiteBertQuestionAnswerer {
|
||||
std::unique_ptr<BertQuestionAnswererCpp> impl;
|
||||
};
|
||||
|
||||
BertQuestionAnswerer* BertQuestionAnswererFromFile(const char* model_path) {
|
||||
TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
|
||||
const char* model_path) {
|
||||
auto bert_qa_status =
|
||||
BertQuestionAnswererCPP::CreateFromFile(std::string(model_path));
|
||||
BertQuestionAnswererCpp::CreateFromFile(std::string(model_path));
|
||||
if (bert_qa_status.ok()) {
|
||||
return new BertQuestionAnswerer{
|
||||
.impl = std::unique_ptr<BertQuestionAnswererCPP>(
|
||||
dynamic_cast<BertQuestionAnswererCPP*>(
|
||||
return new TfLiteBertQuestionAnswerer{
|
||||
.impl = std::unique_ptr<BertQuestionAnswererCpp>(
|
||||
dynamic_cast<BertQuestionAnswererCpp*>(
|
||||
bert_qa_status.value().release()))};
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
QaAnswers* BertQuestionAnswererAnswer(
|
||||
const BertQuestionAnswerer* question_answerer,
|
||||
TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
|
||||
const TfLiteBertQuestionAnswerer* question_answerer,
|
||||
const char* context,
|
||||
const char* question) {
|
||||
std::vector<QaAnswerCPP> answers = question_answerer->impl->Answer(
|
||||
std::vector<QaAnswerCpp> answers = question_answerer->impl->Answer(
|
||||
absl::string_view(context).data(), absl::string_view(question).data());
|
||||
size_t size = answers.size();
|
||||
auto* qa_answers = new QaAnswer[size];
|
||||
auto* qa_answers = new TfLiteQaAnswer[size];
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
qa_answers[i].start = answers[i].pos.start;
|
||||
@ -60,17 +63,23 @@ QaAnswers* BertQuestionAnswererAnswer(
|
||||
qa_answers[i].text = strdup(answers[i].text.c_str());
|
||||
}
|
||||
|
||||
auto* c_answers = new QaAnswers;
|
||||
auto* c_answers = new TfLiteQaAnswers;
|
||||
c_answers->size = size;
|
||||
c_answers->answers = qa_answers;
|
||||
return c_answers;
|
||||
}
|
||||
|
||||
void BertQuestionAnswererDelete(BertQuestionAnswerer* bert_question_answerer) {
|
||||
void TfLiteBertQuestionAnswererDelete(
|
||||
TfLiteBertQuestionAnswerer* bert_question_answerer) {
|
||||
delete bert_question_answerer;
|
||||
}
|
||||
|
||||
void BertQuestionAnswererQaAnswersDelete(QaAnswers* qa_answers) {
|
||||
void TfLiteQaAnswersDelete(TfLiteQaAnswers* qa_answers) {
|
||||
for (int i = 0; i < qa_answers->size; i++) {
|
||||
// `strdup` obtains memory using `malloc` and the memory needs to be
|
||||
// released using `free`.
|
||||
free(qa_answers->answers[i].text);
|
||||
}
|
||||
delete[] qa_answers->answers;
|
||||
delete qa_answers;
|
||||
}
|
74
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
vendored
Normal file
74
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/bert_question_answerer.h
vendored
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for BertQuestionAnswerer.
|
||||
//
|
||||
// Usage:
|
||||
// <pre><code>
|
||||
// // Create the model and interpreter options.
|
||||
// TfLiteBertQuestionAnswerer* qa_answerer =
|
||||
// TfLiteBertQuestionAnswererCreate("/path/to/model.tflite");
|
||||
//
|
||||
// // Answer a question.
|
||||
// TfLiteQaAnswers* answers = TfLiteBertQuestionAnswererAnswer(qa_answerer,
|
||||
// context, question);
|
||||
//
|
||||
// // Dispose of the API and QaAnswers objects.
|
||||
// TfLiteBertQuestionAnswererDelete(qa_answerer);
|
||||
// TfLiteQaAnswersDelete(answers);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
typedef struct TfLiteBertQuestionAnswerer TfLiteBertQuestionAnswerer;
|
||||
|
||||
// One answer returned by TfLiteBertQuestionAnswererAnswer().
typedef struct TfLiteQaAnswer {
  // Start of the answer span (copied from the C++ answer's `pos.start`).
  int start;
  // End of the answer span.
  int end;
  // Logit associated with this answer.
  float logit;
  // NUL-terminated answer text. It is allocated with strdup() by the answerer
  // and released (with free()) by TfLiteQaAnswersDelete().
  char* text;
} TfLiteQaAnswer;

// The list of answers returned by TfLiteBertQuestionAnswererAnswer().
// Release it with TfLiteQaAnswersDelete().
typedef struct TfLiteQaAnswers {
  // Number of entries in `answers`.
  int size;
  // Array holding `size` answers.
  TfLiteQaAnswer* answers;
} TfLiteQaAnswers;
|
||||
|
||||
// Creates TfLiteBertQuestionAnswerer from model path, returns nullptr if the
|
||||
// file doesn't exist or is not a well formatted TFLite model path.
|
||||
TfLiteBertQuestionAnswerer* TfLiteBertQuestionAnswererCreate(
|
||||
const char* model_path);
|
||||
|
||||
// Invokes the encapsulated TFLite model and answers a question based on
|
||||
// context.
|
||||
TfLiteQaAnswers* TfLiteBertQuestionAnswererAnswer(
|
||||
const TfLiteBertQuestionAnswerer* question_answerer,
|
||||
const char* context,
|
||||
const char* question);
|
||||
|
||||
void TfLiteBertQuestionAnswererDelete(
|
||||
TfLiteBertQuestionAnswerer* bert_question_answerer);
|
||||
|
||||
void TfLiteQaAnswersDelete(TfLiteQaAnswers* qa_answers);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_BERT_QUESTION_ANSWERER_H_
|
@ -13,31 +13,33 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h"
|
||||
#include "tensorflow_lite_support/c/task/text/nl_classifier.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/task/core/category.h"
|
||||
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
|
||||
|
||||
using CategoryCPP = ::tflite::task::core::Category;
|
||||
using NLClassifierCPP = ::tflite::task::text::nlclassifier::NLClassifier;
|
||||
using NLClassifierOptionsCPP =
|
||||
namespace {
|
||||
using CategoryCpp = ::tflite::task::core::Category;
|
||||
using NLClassifierCpp = ::tflite::task::text::nlclassifier::NLClassifier;
|
||||
using NLClassifierOptionsCpp =
|
||||
::tflite::task::text::nlclassifier::NLClassifierOptions;
|
||||
} // namespace
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
struct NLClassifier {
|
||||
std::unique_ptr<NLClassifierCPP> impl;
|
||||
struct TfLiteNLClassifier {
|
||||
std::unique_ptr<NLClassifierCpp> impl;
|
||||
};
|
||||
|
||||
NLClassifier* NLClassifierFromFileAndOptions(
|
||||
TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
|
||||
const char* model_path,
|
||||
const NLClassifierOptions* options) {
|
||||
auto classifier_status = NLClassifierCPP::CreateFromFileAndOptions(
|
||||
const TfLiteNLClassifierOptions* options) {
|
||||
auto classifier_status = NLClassifierCpp::CreateFromFileAndOptions(
|
||||
std::string(model_path),
|
||||
{
|
||||
.input_tensor_index = options->input_tensor_index,
|
||||
@ -57,17 +59,17 @@ NLClassifier* NLClassifierFromFileAndOptions(
|
||||
});
|
||||
|
||||
if (classifier_status.ok()) {
|
||||
return new NLClassifier{
|
||||
.impl = std::unique_ptr<NLClassifierCPP>(dynamic_cast<NLClassifierCPP*>(
|
||||
return new TfLiteNLClassifier{
|
||||
.impl = std::unique_ptr<NLClassifierCpp>(dynamic_cast<NLClassifierCpp*>(
|
||||
classifier_status.value().release()))};
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Categories* NLClassifierClassify(const NLClassifier* classifier,
|
||||
const char* text) {
|
||||
std::vector<CategoryCPP> results =
|
||||
Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier,
|
||||
const char* text) {
|
||||
std::vector<CategoryCpp> results =
|
||||
classifier->impl->Classify(absl::string_view(text).data());
|
||||
size_t size = results.size();
|
||||
auto* categories = new Category[size];
|
||||
@ -83,7 +85,7 @@ Categories* NLClassifierClassify(const NLClassifier* classifier,
|
||||
return c_categories;
|
||||
}
|
||||
|
||||
void NLClassifierDelete(NLClassifier* classifier) {
|
||||
void TfLiteNLClassifierDelete(TfLiteNLClassifier* classifier) {
|
||||
delete classifier;
|
||||
}
|
||||
|
64
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
vendored
Normal file
64
third_party/tflite_support/src/tensorflow_lite_support/c/task/text/nl_classifier.h
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_
|
||||
|
||||
#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h"
|
||||
// --------------------------------------------------------------------------
|
||||
// C API for NLClassifier.
|
||||
//
|
||||
// Usage:
|
||||
// // Create the model and interpreter options.
|
||||
// TfLiteNLClassifier* classifier = TfLiteNLClassifierCreateFromOptions(
|
||||
// "/path/to/model.tflite", &options);
|
||||
//
|
||||
// // Classification.
|
||||
// Categories* categories = TfLiteNLClassifierClassify(classifier, question);
|
||||
//
|
||||
// // Dispose of the API object.
|
||||
// TfLiteNLClassifierDelete(classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
typedef struct TfLiteNLClassifier TfLiteNLClassifier;
|
||||
|
||||
// Options read by TfLiteNLClassifierCreateFromOptions() and forwarded to the
// C++ NLClassifierOptions.
typedef struct TfLiteNLClassifierOptions {
  // Index of the model's input tensor.
  int input_tensor_index;
  // Index of the output tensor holding the scores.
  int output_score_tensor_index;
  // Index of the output tensor holding the labels.
  int output_label_tensor_index;
  // Name of the model's input tensor.
  const char* input_tensor_name;
  // Name of the output tensor holding the scores.
  const char* output_score_tensor_name;
  // Name of the output tensor holding the labels.
  const char* output_label_tensor_name;
} TfLiteNLClassifierOptions;
|
||||
|
||||
// Creates TfLiteNLClassifier from model path and options, returns nullptr if
|
||||
// the file doesn't exist or is not a well formatted TFLite model path.
|
||||
TfLiteNLClassifier* TfLiteNLClassifierCreateFromOptions(
|
||||
const char* model_path,
|
||||
const TfLiteNLClassifierOptions* options);
|
||||
|
||||
// Invokes the encapsulated TFLite model and classifies the input text.
|
||||
Categories* TfLiteNLClassifierClassify(const TfLiteNLClassifier* classifier,
|
||||
const char* text);
|
||||
|
||||
void TfLiteNLClassifierDelete(TfLiteNLClassifier* classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_H_
|
@ -13,13 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
|
||||
#include "tensorflow_lite_support/c/task/text/nl_classifier_common.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Releases a Categories object produced by the NLClassifier C APIs: every
// category's `text` string, the `categories` array, and the object itself.
//
// Passing nullptr is a no-op. This matches the other *Delete() functions of
// these C APIs, which use `delete` and therefore accept nullptr.
void NLClassifierCategoriesDelete(Categories* categories) {
  if (categories == nullptr) {
    return;
  }
  for (int i = 0; i < categories->size; ++i) {
    // `strdup` obtains memory using `malloc`, so it must be released with
    // `free` rather than `delete`.
    free(categories->categories[i].text);
  }
  delete[] categories->categories;
  delete categories;
}
|
@ -12,32 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_
|
||||
|
||||
// Common structs shared between NLClassifier APIs
|
||||
//
|
||||
/// // Dispose of the Categories object.
|
||||
/// NLClassifierCategoriesDelete(categories);
|
||||
// C API for the NLClassifier results, Category.
|
||||
|
||||
// TODO(b/197355311): deprecate this class and use the unified one with image
|
||||
// and audio.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
struct Category {
|
||||
typedef struct Category {
|
||||
char* text;
|
||||
double score;
|
||||
};
|
||||
} Category;
|
||||
|
||||
struct Categories {
|
||||
typedef struct Categories {
|
||||
int size;
|
||||
struct Category* categories;
|
||||
};
|
||||
Category* categories;
|
||||
} Categories;
|
||||
|
||||
extern void NLClassifierCategoriesDelete(struct Categories* categories);
|
||||
void NLClassifierCategoriesDelete(Categories* categories);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_TEXT_NL_CLASSIFIER_COMMON_H_
|
50
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/BUILD
vendored
Normal file
50
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/BUILD
vendored
Normal file
@ -0,0 +1,50 @@
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")

package(
    default_visibility = [
        "//tensorflow_lite_support:internal",
    ],
    licenses = ["notice"],  # Apache 2.0
)

# C API for the vision ImageClassifier task.
cc_library_with_tflite(
    name = "image_classifier",
    srcs = [
        "image_classifier.cc",
    ],
    hdrs = [
        "image_classifier.h",
    ],
    # TFLite-dependent C++ implementation wrapped by the C API.
    tflite_deps = [
        "//tensorflow_lite_support/cc/task/vision:image_classifier",
    ],
    deps = [
        "//tensorflow_lite_support/c:common",
        "//tensorflow_lite_support/c:common_utils",
        "//tensorflow_lite_support/c/task/core:base_options",
        "//tensorflow_lite_support/c/task/processor:bounding_box",
        "//tensorflow_lite_support/c/task/processor:classification_options",
        "//tensorflow_lite_support/c/task/processor:classification_result",
        "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
        "//tensorflow_lite_support/c/task/vision/utils:frame_buffer_cpp_c_utils",
        "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
        "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
    ],
)

# C API header for the vision ObjectDetector task (header only here).
cc_library_with_tflite(
    name = "object_detector",
    hdrs = [
        "object_detector.h",
    ],
    tflite_deps = [
        "//tensorflow_lite_support/cc/task/vision:object_detector",
    ],
    deps = [
        "//tensorflow_lite_support/c:common",
        "//tensorflow_lite_support/c/task/core:base_options",
        "//tensorflow_lite_support/c/task/processor:classification_options",
        "//tensorflow_lite_support/c/task/processor:detection_result",
        "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
    ],
)
|
13
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/BUILD
vendored
Normal file
13
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/BUILD
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
package(
    default_visibility = [
        "//tensorflow_lite_support:internal",
    ],
    licenses = ["notice"],  # Apache 2.0
)

# Header-only library declaring the C TfLiteFrameBuffer structs.
cc_library(
    name = "frame_buffer",
    hdrs = [
        "frame_buffer.h",
    ],
)
|
81
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/frame_buffer.h
vendored
Normal file
81
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/core/frame_buffer.h
vendored
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Defines C structs for holding the frame buffer.
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Colorspace formats. Values are spelled out explicitly; they match the
// implicit 0-based numbering.
enum TfLiteFrameBufferFormat {
  kRGBA = 0,
  kRGB = 1,
  kNV12 = 2,
  kNV21 = 3,
  kYV12 = 4,
  kYV21 = 5,
  kGRAY = 6,
  kUNKNOWN = 7
};

// Content orientation, following the EXIF specification. Each enumerator
// names where the 0th row and the 0th column of the image content sit. See
// http://jpegclub.org/exif_orientation.html for details.
enum TfLiteFrameBufferOrientation {
  kTopLeft = 1,
  kTopRight = 2,
  kBottomRight = 3,
  kBottomLeft = 4,
  kLeftTop = 5,
  kRightTop = 6,
  kRightBottom = 7,
  kLeftBottom = 8
};

// Size of the whole frame, in pixels.
struct TfLiteFrameBufferDimension {
  // Width in pixels.
  int width;
  // Height in pixels.
  int height;
};

// A non-owning view over a caller-provided backing buffer (e.g. a camera frame
// or a still image), annotated with its format information. The caller must
// keep the backing buffer alive for as long as the TfLiteFrameBuffer is used.
typedef struct TfLiteFrameBuffer {
  // Colorspace format of the pixels in `buffer`.
  enum TfLiteFrameBufferFormat format;
  // EXIF orientation of the content.
  enum TfLiteFrameBufferOrientation orientation;
  // Size of the whole frame.
  struct TfLiteFrameBufferDimension dimension;
  // Backing pixel data. Only single-planar images are supported for now.
  uint8_t* buffer;
} TfLiteFrameBuffer;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_H_
|
236
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
vendored
Normal file
236
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.cc
vendored
Normal file
@ -0,0 +1,236 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/task/vision/image_classifier.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow_lite_support/c/common_utils.h"
|
||||
#include "tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h"
|
||||
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
|
||||
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
|
||||
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
|
||||
|
||||
namespace {
|
||||
using ::tflite::support::StatusOr;
|
||||
using ClassificationResultCpp = ::tflite::task::vision::ClassificationResult;
|
||||
using ClassificationsCpp = ::tflite::task::vision::Classifications;
|
||||
using ClassCpp = ::tflite::task::vision::Class;
|
||||
using BoundingBoxCpp = ::tflite::task::vision::BoundingBox;
|
||||
using ImageClassifierCpp = ::tflite::task::vision::ImageClassifier;
|
||||
using ImageClassifierOptionsCpp =
|
||||
::tflite::task::vision::ImageClassifierOptions;
|
||||
using FrameBufferCpp = ::tflite::task::vision::FrameBuffer;
|
||||
using ::tflite::support::TfLiteSupportStatus;
|
||||
|
||||
StatusOr<ImageClassifierOptionsCpp> CreateImageClassifierCppOptionsFromCOptions(
|
||||
const TfLiteImageClassifierOptions* c_options) {
|
||||
if (c_options == nullptr) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Expected non null options."),
|
||||
TfLiteSupportStatus::kInvalidArgumentError);
|
||||
}
|
||||
|
||||
ImageClassifierOptionsCpp cpp_options = {};
|
||||
|
||||
// More file sources can be added in else ifs
|
||||
if (c_options->base_options.model_file.file_path)
|
||||
cpp_options.mutable_base_options()->mutable_model_file()->set_file_name(
|
||||
c_options->base_options.model_file.file_path);
|
||||
|
||||
// c_options->base_options.compute_settings.num_threads is expected to be
|
||||
// set to value > 0 or -1. Otherwise invoking
|
||||
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
|
||||
cpp_options.mutable_base_options()
|
||||
->mutable_compute_settings()
|
||||
->mutable_tflite_settings()
|
||||
->mutable_cpu_settings()
|
||||
->set_num_threads(
|
||||
c_options->base_options.compute_settings.cpu_settings.num_threads);
|
||||
|
||||
for (int i = 0; i < c_options->classification_options.label_denylist.length;
|
||||
i++)
|
||||
cpp_options.add_class_name_blacklist(
|
||||
c_options->classification_options.label_denylist.list[i]);
|
||||
|
||||
for (int i = 0; i < c_options->classification_options.label_allowlist.length;
|
||||
i++)
|
||||
cpp_options.add_class_name_whitelist(
|
||||
c_options->classification_options.label_allowlist.list[i]);
|
||||
|
||||
// Check needed since setting a nullptr for this field results in a segfault
|
||||
// on invocation of ImageClassifierCpp::CreateFromOptions().
|
||||
if (c_options->classification_options.display_names_local) {
|
||||
cpp_options.set_display_names_locale(
|
||||
c_options->classification_options.display_names_local);
|
||||
}
|
||||
|
||||
// c_options->classification_options.max_results is expected to be set to -1
|
||||
// or any value > 0. Otherwise invoking
|
||||
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
|
||||
cpp_options.set_max_results(c_options->classification_options.max_results);
|
||||
|
||||
cpp_options.set_score_threshold(
|
||||
c_options->classification_options.score_threshold);
|
||||
|
||||
return cpp_options;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
struct TfLiteImageClassifier {
|
||||
std::unique_ptr<ImageClassifierCpp> impl;
|
||||
};
|
||||
|
||||
// Returns options with every field zero-initialized, except for these
// defaults:
//   base_options.compute_settings.cpu_settings.num_threads = -1
//   classification_options.score_threshold = 0.0
//   classification_options.max_results = -1
TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() {
  // NOTE: the nested `{{{0}}}` zero-initializer is deliberate; per the
  // original note, using a brace-enclosed initializer list breaks the Kokoro
  // test.
  TfLiteImageClassifierOptions defaults = {{{0}}};
  defaults.base_options.compute_settings.cpu_settings.num_threads = -1;
  defaults.classification_options.score_threshold = 0.0;
  defaults.classification_options.max_results = -1;
  return defaults;
}
|
||||
|
||||
// Creates a TfLiteImageClassifier from `options`.
//
// On failure (invalid options, or the C++ classifier cannot be created),
// returns nullptr and passes the failing status to
// CreateTfLiteSupportErrorWithStatus() together with `error`.
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
    const TfLiteImageClassifierOptions* options,
    TfLiteSupportError** error) {
  StatusOr<ImageClassifierOptionsCpp> options_cpp =
      CreateImageClassifierCppOptionsFromCOptions(options);
  if (!options_cpp.ok()) {
    ::tflite::support::CreateTfLiteSupportErrorWithStatus(options_cpp.status(),
                                                          error);
    return nullptr;
  }

  StatusOr<std::unique_ptr<ImageClassifierCpp>> created =
      ImageClassifierCpp::CreateFromOptions(options_cpp.value());
  if (!created.ok()) {
    ::tflite::support::CreateTfLiteSupportErrorWithStatus(created.status(),
                                                          error);
    return nullptr;
  }

  return new TfLiteImageClassifier{.impl = std::move(created.value())};
}
|
||||
|
||||
// Converts a C++ ClassificationResult into a newly allocated
// TfLiteClassificationResult.
//
// Everything returned is heap-allocated:
//   - the result object itself (new);
//   - its `classifications` array and each head's `categories` array (new[]);
//   - every non-null `label` / `display_name` string (strdup).
// NOTE(review): releasing all of this is the caller's job, presumably via the
// classification_result C API's delete function (not visible here).
TfLiteClassificationResult* GetClassificationResultCStruct(
    const ClassificationResultCpp& classification_result_cpp) {
  const int num_heads = classification_result_cpp.classifications_size();
  auto* c_classifications = new TfLiteClassifications[num_heads];

  for (int head = 0; head < num_heads; ++head) {
    const ClassificationsCpp& classifications =
        classification_result_cpp.classifications(head);
    const int num_categories = classifications.classes_size();

    auto* c_categories = new TfLiteCategory[num_categories];
    for (int rank = 0; rank < num_categories; ++rank) {
      const ClassCpp& classification = classifications.classes(rank);
      c_categories[rank].index = classification.index();
      c_categories[rank].score = classification.score();
      // Optional strings are exposed to C as nullptr when absent.
      c_categories[rank].label =
          classification.has_class_name()
              ? strdup(classification.class_name().c_str())
              : nullptr;
      c_categories[rank].display_name =
          classification.has_display_name()
              ? strdup(classification.display_name().c_str())
              : nullptr;
    }

    c_classifications[head].head_index = head;
    // Index by `head`: each head carries its own category count. Writing
    // through `c_classifications->` would always target head 0 and leave the
    // other heads' `size` uninitialized.
    c_classifications[head].size = num_categories;
    c_classifications[head].categories = c_categories;
  }

  auto* c_classification_result = new TfLiteClassificationResult;
  c_classification_result->classifications = c_classifications;
  c_classification_result->size = num_heads;
  return c_classification_result;
}
|
||||
|
||||
// Classifies the region `roi` of `frame_buffer`. When `roi` is nullptr, the
// whole frame is used.
//
// Returns a newly allocated result. Returns nullptr and reports through
// `error` when the classifier is null, the frame buffer cannot be converted to
// its C++ form, or classification fails.
TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
    const TfLiteImageClassifier* classifier,
    const TfLiteFrameBuffer* frame_buffer,
    const TfLiteBoundingBox* roi,
    TfLiteSupportError** error) {
  if (classifier == nullptr) {
    tflite::support::CreateTfLiteSupportError(
        kInvalidArgumentError, "Expected non null image classifier.", error);
    return nullptr;
  }

  StatusOr<std::unique_ptr<FrameBufferCpp>> cpp_frame_buffer_status =
      ::tflite::task::vision::CreateCppFrameBuffer(frame_buffer);
  if (!cpp_frame_buffer_status.ok()) {
    tflite::support::CreateTfLiteSupportErrorWithStatus(
        cpp_frame_buffer_status.status(), error);
    return nullptr;
  }

  BoundingBoxCpp cc_roi;
  if (roi == nullptr) {
    // Whole frame. origin_x / origin_y are left at their proto defaults.
    // NOTE(review): dereferences frame_buffer; this relies on
    // CreateCppFrameBuffer() having rejected a null frame_buffer — confirm.
    cc_roi.set_width(frame_buffer->dimension.width);
    cc_roi.set_height(frame_buffer->dimension.height);
  } else {
    cc_roi.set_origin_x(roi->origin_x);
    cc_roi.set_origin_y(roi->origin_y);
    cc_roi.set_width(roi->width);
    cc_roi.set_height(roi->height);
  }

  // operator* on the unique_ptr yields the FrameBuffer by reference; no move
  // is needed.
  StatusOr<ClassificationResultCpp> cpp_classification_result_status =
      classifier->impl->Classify(*cpp_frame_buffer_status.value(), cc_roi);

  if (!cpp_classification_result_status.ok()) {
    tflite::support::CreateTfLiteSupportErrorWithStatus(
        cpp_classification_result_status.status(), error);
    return nullptr;
  }

  return GetClassificationResultCStruct(
      cpp_classification_result_status.value());
}
|
||||
|
||||
// Classifies the entire `frame_buffer`. This is TfLiteImageClassifierClassifyWithRoi()
// called with a null region of interest.
TfLiteClassificationResult* TfLiteImageClassifierClassify(
    const TfLiteImageClassifier* classifier,
    const TfLiteFrameBuffer* frame_buffer,
    TfLiteSupportError** error) {
  const TfLiteBoundingBox* const kWholeFrame = nullptr;
  return TfLiteImageClassifierClassifyWithRoi(classifier, frame_buffer,
                                              kWholeFrame, error);
}
|
||||
|
||||
// Destroys a classifier returned by TfLiteImageClassifierFromOptions(),
// together with the C++ classifier it owns. A nullptr argument is a no-op,
// as with any `delete`.
void TfLiteImageClassifierDelete(TfLiteImageClassifier* classifier) {
  delete classifier;
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
214
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
vendored
Normal file
214
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/image_classifier.h
vendored
Normal file
@ -0,0 +1,214 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow_lite_support/c/common.h"
|
||||
#include "tensorflow_lite_support/c/task/core/base_options.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/bounding_box.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/classification_options.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/classification_result.h"
|
||||
#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h"
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
/// C API for ImageClassifier.
|
||||
///
|
||||
/// The API leans towards simplicity and uniformity instead of convenience, as
|
||||
/// most usage will be by language-specific wrappers. It provides largely the
|
||||
/// same set of functionality as that of the C++ TensorFlow Lite
|
||||
/// `ImageClassifier` API, but is useful for shared libraries where having
|
||||
/// a stable ABI boundary is important.
|
||||
///
|
||||
/// Usage:
|
||||
/// <pre><code>
|
||||
/// // Create the model
|
||||
/// Using the options initialized with default values returned by
|
||||
/// TfLiteImageClassifierOptionsCreate() makes sure that there will be no
|
||||
/// undefined behaviour due to garbage values in uninitialized members.
|
||||
/// TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
///
|
||||
/// Set the model file path in options
|
||||
/// options.base_options.model_file.file_path = "/path/to/model.tflite";
|
||||
///
|
||||
/// If need be, set values for any options to customize behaviour.
|
||||
/// options.base_options.compute_settings.cpu_settings.num_threads = 3;
|
||||
///
|
||||
/// Create TfLiteImageClassifier using the options:
|
||||
/// If error information is not needed in case of failure:
|
||||
/// TfLiteImageClassifier* image_classifier =
|
||||
/// TfLiteImageClassifierFromOptions(&options, NULL);
|
||||
///
|
||||
/// If error information is needed in case of failure:
|
||||
/// TfLiteSupportError* create_error = NULL;
|
||||
/// TfLiteImageClassifier* image_classifier =
|
||||
/// TfLiteImageClassifierFromOptions(&options, &create_error);
|
||||
///
|
||||
/// if (!image_classifier) {
|
||||
/// Handle failure.
|
||||
/// Do something with `create_error`, if requested as illustrated above.
|
||||
/// }
|
||||
///
|
||||
/// Dispose of the create_error object.
|
||||
/// TfLiteSupportErrorDelete(create_error);
|
||||
///
|
||||
/// Classify an image
|
||||
/// TfLiteFrameBuffer frame_buffer = { Initialize with image data }
|
||||
///
|
||||
/// If error information is not needed in case of failure:
|
||||
/// TfLiteClassificationResult* classification_result =
|
||||
/// TfLiteImageClassifierClassify(image_classifier, &frame_buffer, NULL);
|
||||
///
|
||||
/// If error information is needed in case of failure:
|
||||
/// TfLiteSupportError* classify_error = NULL;
|
||||
/// TfLiteClassificationResult* classification_result =
|
||||
/// TfLiteImageClassifierClassify(image_classifier, &frame_buffer,
|
||||
/// &classify_error);
|
||||
///
|
||||
/// if (!classification_result) {
|
||||
/// Handle failure.
|
||||
/// Do something with `classify_error`, if requested as illustrated above.
|
||||
/// }
|
||||
///
|
||||
/// Dispose of the classify_error object.
|
||||
/// TfLiteSupportErrorDelete(classify_error);
|
||||
///
|
||||
/// Dispose of the API object.
|
||||
/// TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
typedef struct TfLiteImageClassifier TfLiteImageClassifier;
|
||||
|
||||
typedef struct TfLiteImageClassifierOptions {
|
||||
TfLiteClassificationOptions classification_options;
|
||||
TfLiteBaseOptions base_options;
|
||||
} TfLiteImageClassifierOptions;
|
||||
|
||||
// Creates and returns TfLiteImageClassifierOptions initialized with default
|
||||
// values. Default values are as follows:
|
||||
// 1. .classification_options.max_results = -1, which returns all classification
|
||||
// categories by default.
|
||||
// 2. .base_options.compute_settings.cpu_settings.num_threads =
|
||||
// -1, which makes the TFLite runtime choose the value.
|
||||
// 3. .classification_options.score_threshold = 0
|
||||
// 4. All pointers like .base_options.model_file.file_path,
|
||||
// .classification_options.display_names_local,
|
||||
// .classification_options.label_allowlist.list,
|
||||
// .classification_options.label_denylist.list are NULL.
|
||||
// 5. All other integer values are initialized to 0.
|
||||
TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate();
|
||||
|
||||
// Creates TfLiteImageClassifier from options.
|
||||
// .base_options.model_file.file_path in TfLiteImageClassifierOptions should be
|
||||
// set to the path of the tflite model you wish to create the
|
||||
// TfLiteImageClassifier with.
|
||||
// Create TfLiteImageClassifierOptions using
|
||||
// TfLiteImageClassifierOptionsCreate(). If need be, you can change the default
|
||||
// values of options for customizing classification. If options are not created
|
||||
// in the aforementioned way, you have to make sure that all members are
|
||||
// initialized to respective default values and all pointer members are zero
|
||||
// initialized to avoid any undefined behaviour.
|
||||
//
|
||||
// Returns the created image classifier in case of success.
|
||||
// Returns nullptr on failure which happens commonly due to one of the following
|
||||
// issues:
|
||||
// 1. file doesn't exist or is not well formatted.
|
||||
// 2. options is nullptr.
|
||||
// 3. Both options.classification_options.label_denylist and
|
||||
// options.classification_options.label_allowlist are non empty. These
|
||||
// fields are mutually exclusive.
|
||||
//
|
||||
// The caller can check if an error was encountered by testing if the returned
|
||||
// value of the function is null. If the caller doesn't want the reason for
|
||||
// failure, they can simply pass a NULL for the address of the error pointer as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteImageClassifier* classifier = TfLiteImageClassifierFromOptions(options,
|
||||
// NULL);
|
||||
//
|
||||
// If the caller wants to be informed of the reason for failure, they must pass
|
||||
// the address of a pointer of type TfLiteSupportError to the `error` param as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteSupportError *error = NULL;
|
||||
// TfLiteImageClassifier* classifier = TfLiteImageClassifierFromOptions(options,
|
||||
// &error);
|
||||
//
|
||||
// In case of unsuccessful execution, once the function returns, the error
|
||||
// pointer will point to a struct containing the error information. If error
|
||||
// info is passed back to the caller, it is the responsibility of the caller to
|
||||
// free the error struct by calling the following method defined in common.h:
|
||||
//
|
||||
// TfLiteSupportErrorDelete(error)
|
||||
//
|
||||
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
|
||||
const TfLiteImageClassifierOptions* options,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Invokes the encapsulated TFLite model and classifies the frame_buffer.
|
||||
// Returns a pointer to the created classification result in case of success or
|
||||
// NULL in case of failure. The caller must test the return value to identify
|
||||
// success or failure. If the caller doesn't want the reason for failure, they
|
||||
// can simply pass a NULL for the address of the error pointer as shown below:
|
||||
//
|
||||
// TfLiteClassificationResult* classification_result =
|
||||
// TfLiteImageClassifierClassify(classifier, &frame_buffer, NULL);
|
||||
//
|
||||
// If the caller wants to be informed of the reason for failure, they must pass
|
||||
// the address of a pointer of type TfLiteSupportError to the `error` param as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteSupportError *error = NULL;
|
||||
// TfLiteClassificationResult* classification_result =
|
||||
// TfLiteImageClassifierClassify(classifier, &frame_buffer, &error);
|
||||
//
|
||||
// In case of unsuccessful execution, once the function returns, the error
|
||||
// pointer will point to a struct containing the error information. If error
|
||||
// info is passed back to the caller, it is the responsibility of the caller to
|
||||
// free the error struct by calling the following method defined in common.h:
|
||||
//
|
||||
// TfLiteSupportErrorDelete(error)
|
||||
//
|
||||
TfLiteClassificationResult* TfLiteImageClassifierClassify(
|
||||
const TfLiteImageClassifier* classifier,
|
||||
const TfLiteFrameBuffer* frame_buffer,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Invokes the encapsulated TFLite model and classifies the region of the
|
||||
// frame_buffer specified by the bounding box. Same as
|
||||
// TfLiteImageClassifierClassify(const TfLiteImageClassifier* classifier,
|
||||
// const TfLiteFrameBuffer* frame_buffer, TfLiteSupportError** error),
|
||||
// except that the
|
||||
// classification is performed based on the input region of interest. Cropping
|
||||
// according to this region of interest is prepended to the pre-processing
|
||||
// operations.
|
||||
TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
|
||||
const TfLiteImageClassifier* classifier,
|
||||
const TfLiteFrameBuffer* frame_buffer,
|
||||
const TfLiteBoundingBox* roi,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Disposes of the image classifier.
|
||||
void TfLiteImageClassifierDelete(TfLiteImageClassifier* classifier);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_IMAGE_CLASSIFIER_H_
|
200
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
vendored
Normal file
200
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/object_detector.h
vendored
Normal file
@ -0,0 +1,200 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Include guard. It must be unique to this header: sharing
// image_classifier.h's macro would make whichever of the two headers is
// included second in a translation unit expand to nothing.
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_OBJECT_DETECTOR_H_
#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_OBJECT_DETECTOR_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow_lite_support/c/common.h"
|
||||
#include "tensorflow_lite_support/c/task/core/base_options.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/classification_options.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/detection_result.h"
|
||||
#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h"
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
/// C API for Object Detector.
|
||||
///
|
||||
/// The API leans towards simplicity and uniformity instead of convenience, as
|
||||
/// most usage will be by language-specific wrappers. It provides largely the
|
||||
/// same set of functionality as that of the C++ TensorFlow Lite
|
||||
/// `ObjectDetector` API, but is useful for shared libraries where having
|
||||
/// a stable ABI boundary is important.
|
||||
///
|
||||
/// Usage:
|
||||
/// <pre><code>
|
||||
/// // Create the model
|
||||
/// Using the options initialized with default values returned by
|
||||
/// TfLiteObjectDetectorOptionsCreate() makes sure that there will be no
|
||||
/// undefined behaviour due to garbage values in uninitialized members.
|
||||
/// TfLiteObjectDetectorOptions options = TfLiteObjectDetectorOptionsCreate();
|
||||
///
|
||||
/// Set the model file path in options
|
||||
/// options.base_options.model_file.file_path = "/path/to/model.tflite";
|
||||
///
|
||||
/// If need be, set values for any options to customize behaviour.
|
||||
/// options.base_options.compute_settings.cpu_settings.num_threads = 3;
|
||||
///
|
||||
/// Create TfLiteObjectDetector using the options:
|
||||
/// If error information is not needed in case of failure:
|
||||
/// TfLiteObjectDetector* object_detector =
|
||||
/// TfLiteObjectDetectorFromOptions(&options, NULL);
|
||||
///
|
||||
/// If error information is needed in case of failure:
|
||||
/// TfLiteSupportError* create_error = NULL;
|
||||
/// TfLiteObjectDetector* object_detector =
|
||||
/// TfLiteObjectDetectorFromOptions(&options, &create_error);
|
||||
///
|
||||
/// if (!object_detector) {
|
||||
/// Handle failure.
|
||||
/// Do something with `create_error`, if requested as illustrated above.
|
||||
/// }
|
||||
///
|
||||
/// Dispose of the create_error object.
|
||||
/// TfLiteSupportErrorDelete(create_error);
|
||||
///
|
||||
/// Detect objects in an image
|
||||
/// TfLiteFrameBuffer frame_buffer = { Initialize with image data }
|
||||
///
|
||||
/// If error information is not needed in case of failure:
|
||||
/// TfLiteDetectionResult* detection_result =
|
||||
/// TfLiteObjectDetectorDetect(object_detector, &frame_buffer, NULL);
|
||||
///
|
||||
/// If error information is needed in case of failure:
|
||||
/// TfLiteSupportError* detect_error = NULL;
|
||||
/// TfLiteDetectionResult* detection_result =
|
||||
/// TfLiteObjectDetectorDetect(object_detector, &frame_buffer,
|
||||
/// &detect_error);
|
||||
///
|
||||
/// if (!detection_result) {
|
||||
/// Handle failure.
|
||||
/// Do something with `detect_error`, if requested as illustrated above.
|
||||
/// }
|
||||
///
|
||||
/// Dispose of the detect_error object.
|
||||
/// TfLiteSupportErrorDelete(detect_error);
|
||||
///
|
||||
/// Dispose of the API object.
|
||||
/// TfLiteObjectDetectorDelete(object_detector);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
typedef struct TfLiteObjectDetector TfLiteObjectDetector;
|
||||
|
||||
typedef struct TfLiteObjectDetectorOptions {
|
||||
TfLiteClassificationOptions classification_options;
|
||||
TfLiteBaseOptions base_options;
|
||||
} TfLiteObjectDetectorOptions;
|
||||
|
||||
// Creates and returns TfLiteObjectDetectorOptions initialized with default
|
||||
// values. Default values are as follows:
|
||||
// 1. .classification_options.max_results = -1, which returns all classification
|
||||
// categories by default.
|
||||
// 2. .base_options.compute_settings.cpu_settings.num_threads =
|
||||
// -1, which makes the TFLite runtime choose the value.
|
||||
// 3. .classification_options.score_threshold = 0
|
||||
// 4. All pointers like .base_options.model_file.file_path,
|
||||
// .classification_options.display_names_local,
|
||||
// .classification_options.label_allowlist.list,
|
||||
// .classification_options.label_denylist.list are NULL.
|
||||
// 5. All other integer values are initialized to 0.
|
||||
TfLiteObjectDetectorOptions TfLiteObjectDetectorOptionsCreate();
|
||||
|
||||
// Creates TfLiteObjectDetector from options.
|
||||
// .base_options.model_file.file_path in TfLiteObjectDetectorOptions should be
|
||||
// set to the path of the tflite model you wish to create the
|
||||
// TfLiteObjectDetector with.
|
||||
// Create TfLiteObjectDetectorOptions using
|
||||
// TfLiteObjectDetectorOptionsCreate(). If need be, you can change the default
|
||||
// values of options for customizing classification. If options are not created
|
||||
// in the aforementioned way, you have to make sure that all members are
|
||||
// initialized to respective default values and all pointer members are zero
|
||||
// initialized to avoid any undefined behaviour.
|
||||
//
|
||||
// Returns the created object detector in case of success.
|
||||
// Returns nullptr on failure which happens commonly due to one of the following
|
||||
// issues:
|
||||
// 1. file doesn't exist or is not well formatted.
|
||||
// 2. options is nullptr.
|
||||
// 3. Both options.classification_options.label_denylist and
|
||||
// options.classification_options.label_allowlist are non empty. These
|
||||
// fields are mutually exclusive.
|
||||
//
|
||||
// The caller can check if an error was encountered by testing if the returned
|
||||
// value of the function is null. If the caller doesn't want the reason for
|
||||
// failure, they can simply pass a NULL for the address of the error pointer as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteObjectDetector* detector = TfLiteObjectDetectorFromOptions(options,
|
||||
// NULL);
|
||||
//
|
||||
// If the caller wants to be informed of the reason for failure, they must pass
|
||||
// the address of a pointer of type TfLiteSupportError to the `error` param as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteSupportError *error = NULL;
|
||||
// TfLiteObjectDetector* detector = TfLiteObjectDetectorFromOptions(options,
|
||||
// &error);
|
||||
//
|
||||
// In case of unsuccessful execution, once the function returns, the error
|
||||
// pointer will point to a struct containing the error information. If error
|
||||
// info is passed back to the caller, it is the responsibility of the caller to
|
||||
// free the error struct by calling the following method defined in common.h:
|
||||
//
|
||||
// TfLiteSupportErrorDelete(error)
|
||||
//
|
||||
TfLiteObjectDetector* TfLiteObjectDetectorFromOptions(
|
||||
const TfLiteObjectDetectorOptions* options,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Invokes the encapsulated TFLite model and performs object detection on the
|
||||
// frame_buffer. Returns a pointer to the created object detection result
|
||||
// in case of success or NULL in case of failure. The caller must test the
|
||||
// return value to identify success or failure. If the caller doesn't want the
|
||||
// reason for failure, they can simply pass a NULL for the address of the error
|
||||
// pointer as shown below:
|
||||
//
|
||||
// TfLiteDetectionResult* detection_result =
|
||||
// TfLiteObjectDetectorDetect(detector, &frame_buffer, NULL);
|
||||
//
|
||||
// If the caller wants to be informed of the reason for failure, they must pass
|
||||
// the address of a pointer of type TfLiteSupportError to the `error` param as
|
||||
// shown below:
|
||||
//
|
||||
// TfLiteSupportError *error = NULL;
|
||||
// TfLiteDetectionResult* detection_result =
|
||||
// TfLiteObjectDetectorDetect(detector, &frame_buffer, &error);
|
||||
//
|
||||
// In case of unsuccessful execution, once the function returns, the error
|
||||
// pointer will point to a struct containing the error information. If error
|
||||
// info is passed back to the caller, it is the responsibility of the caller to
|
||||
// free the error struct by calling the following method defined in common.h:
|
||||
//
|
||||
// TfLiteSupportErrorDelete(error)
|
||||
//
|
||||
TfLiteDetectionResult* TfLiteObjectDetectorDetect(
|
||||
const TfLiteObjectDetector* detector,
|
||||
const TfLiteFrameBuffer* frame_buffer,
|
||||
TfLiteSupportError** error);
|
||||
|
||||
// Disposes of the object detector.
|
||||
void TfLiteObjectDetectorDelete(TfLiteObjectDetector* detector);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif  // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_OBJECT_DETECTOR_H_
|
22
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/BUILD
vendored
Normal file
22
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/BUILD
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
# Helpers that convert the C API FrameBuffer into the C++ FrameBuffer used by
# the vision task implementations.
package(
    default_visibility = ["//tensorflow_lite_support:internal"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "frame_buffer_cpp_c_utils",
    srcs = ["frame_buffer_cpp_c_utils.cc"],
    hdrs = ["frame_buffer_cpp_c_utils.h"],
    deps = [
        # C-side TfLiteFrameBuffer (the conversion input).
        "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
        # CreateStatusWithPayload / TfLiteSupportStatus.
        "//tensorflow_lite_support/cc:common",
        # C++ FrameBuffer helpers, included by frame_buffer_cpp_c_utils.h.
        "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
        "@com_google_absl//absl/strings:str_format",
    ],
)
|
50
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.cc
vendored
Normal file
50
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.cc
vendored
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h"
|
||||
|
||||
#include "absl/strings/str_format.h" // from @com_google_absl
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
namespace vision {
|
||||
|
||||
namespace {
|
||||
using FrameBufferCpp = ::tflite::task::vision::FrameBuffer;
|
||||
using ::tflite::support::StatusOr;
|
||||
using ::tflite::support::TfLiteSupportStatus;
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<FrameBufferCpp>> CreateCppFrameBuffer(
|
||||
const TfLiteFrameBuffer* frame_buffer) {
|
||||
if (frame_buffer == nullptr)
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Expected non null frame buffer."),
|
||||
TfLiteSupportStatus::kInvalidArgumentError);
|
||||
|
||||
FrameBufferCpp::Format frame_buffer_format =
|
||||
FrameBufferCpp::Format(frame_buffer->format);
|
||||
|
||||
return CreateFromRawBuffer(
|
||||
frame_buffer->buffer,
|
||||
{frame_buffer->dimension.width, frame_buffer->dimension.height},
|
||||
frame_buffer_format);
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace task
|
||||
} // namespace tflite
|
36
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h
vendored
Normal file
36
third_party/tflite_support/src/tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_
|
||||
|
||||
#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h"
|
||||
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
|
||||
|
||||
// Utils for Conversions between C and C++ FrameBuffer
|
||||
// -----------------------------------------------------------------
|
||||
// Meant to be used with vision C apis.
|
||||
|
||||
// Creates the C++ FrameBuffer from the C FrameBuffer
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
namespace vision {
|
||||
|
||||
tflite::support::StatusOr<std::unique_ptr<tflite::task::vision::FrameBuffer>>
|
||||
CreateCppFrameBuffer(const TfLiteFrameBuffer* frame_buffer);
|
||||
|
||||
} // namespace vision
|
||||
} // namespace task
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_C_TASK_VISION_FRAME_BUFFER_CPP_C_UTILS_H_
|
34
third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
vendored
Normal file
34
third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/BUILD
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_test_with_tflite")

package(
    default_visibility = ["//visibility:private"],
    licenses = ["notice"],  # Apache 2.0
)

# Tests for the image classifier C API. Run from a desktop terminal with:
#   bazel test tensorflow_lite_support/c/test/task/vision:image_classifier_test
cc_test_with_tflite(
    name = "image_classifier_test",
    srcs = ["image_classifier_test.cc"],
    # Test images and models read at runtime.
    data = [
        "//tensorflow_lite_support/cc/test/testdata/task/vision:test_images",
        "//tensorflow_lite_support/cc/test/testdata/task/vision:test_models",
    ],
    tflite_deps = [
        "//tensorflow_lite_support/c/task/vision:image_classifier",
        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
    ],
    deps = [
        "//tensorflow_lite_support/c:common",
        "//tensorflow_lite_support/c/task/processor:classification_result",
        "//tensorflow_lite_support/c/task/vision/core:frame_buffer",
        "//tensorflow_lite_support/cc/port:gtest_main",
        "//tensorflow_lite_support/cc/test:test_utils",
        "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
    ],
)
|
432
third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
vendored
Normal file
432
third_party/tflite_support/src/tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc
vendored
Normal file
@ -0,0 +1,432 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow_lite_support/c/task/vision/image_classifier.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow_lite_support/c/common.h"
|
||||
#include "tensorflow_lite_support/c/task/processor/classification_result.h"
|
||||
#include "tensorflow_lite_support/c/task/vision/core/frame_buffer.h"
|
||||
#include "tensorflow_lite_support/cc/port/gmock.h"
|
||||
#include "tensorflow_lite_support/cc/port/gtest.h"
|
||||
#include "tensorflow_lite_support/cc/port/status_matchers.h"
|
||||
#include "tensorflow_lite_support/cc/test/test_utils.h"
|
||||
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace task {
|
||||
namespace vision {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
using ::tflite::support::StatusOr;
|
||||
using ::tflite::task::JoinPath;
|
||||
|
||||
// Vision test-data directory, joined onto the test source root by the helpers
// below.
constexpr char kTestDataDirectory[] =
    "/tensorflow_lite_support/cc/test/testdata/task/vision/";

// File name of the quantized MobileNet test model.
constexpr char kMobileNetQuantizedWithMetadata[] =
    "mobilenet_v1_0.25_224_quant.tflite";
|
||||
|
||||
// Decodes the test image `image_name` from kTestDataDirectory, resolved
// relative to "./" (the test source directory).
StatusOr<ImageData> LoadImage(const char* image_name) {
  const std::string image_path =
      JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name);
  return DecodeImageFromFile(image_path);
}
|
||||
|
||||
class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest, FailsWithNullOptionsAndError) {
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(nullptr, &error);
|
||||
|
||||
EXPECT_EQ(image_classifier, nullptr);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("Expected non null options"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) {
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, nullptr);
|
||||
EXPECT_EQ(image_classifier, nullptr);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPathAndError) {
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, &error);
|
||||
|
||||
EXPECT_EQ(image_classifier, nullptr);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("`base_options.model_file`"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata);
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, nullptr);
|
||||
|
||||
EXPECT_NE(image_classifier, nullptr);
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreadsAndError) {
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata);
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
options.base_options.compute_settings.cpu_settings.num_threads = 3;
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, &error);
|
||||
|
||||
EXPECT_NE(image_classifier, nullptr);
|
||||
EXPECT_EQ(error, nullptr);
|
||||
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
if (error)
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierFromOptionsTest,
|
||||
FailsWithClassNameDenyListAndClassNameAllowListAndError) {
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata);
|
||||
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
|
||||
char* label_denylist[9] = {(char*)"brambling"};
|
||||
options.classification_options.label_denylist.list = label_denylist;
|
||||
options.classification_options.label_denylist.length = 1;
|
||||
|
||||
char* label_allowlist[12] = {(char*)"cheeseburger"};
|
||||
options.classification_options.label_allowlist.list = label_allowlist;
|
||||
options.classification_options.label_allowlist.length = 1;
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, &error);
|
||||
|
||||
EXPECT_EQ(image_classifier, nullptr);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("mutually exclusive options"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST(ImageClassifierNullClassifierClassifyTest,
|
||||
FailsWithNullImageClassifierAndError) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(nullptr, nullptr, &error);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
EXPECT_EQ(classification_result, nullptr);
|
||||
if (classification_result)
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("Expected non null image classifier"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata);
|
||||
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
image_classifier = TfLiteImageClassifierFromOptions(&options, nullptr);
|
||||
ASSERT_NE(image_classifier, nullptr);
|
||||
}
|
||||
|
||||
void TearDown() override { TfLiteImageClassifierDelete(image_classifier); }
|
||||
TfLiteImageClassifier* image_classifier;
|
||||
};
|
||||
|
||||
TEST_F(ImageClassifierClassifyTest, SucceedsWithImageData) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {
|
||||
.format = kRGB,
|
||||
.orientation = kTopLeft,
|
||||
.dimension = {.width = image_data.width, .height = image_data.height},
|
||||
.buffer = image_data.pixel_data};
|
||||
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
ASSERT_NE(classification_result, nullptr);
|
||||
EXPECT_GE(classification_result->size, 1);
|
||||
EXPECT_NE(classification_result->classifications, nullptr);
|
||||
EXPECT_GE(classification_result->classifications->size, 1);
|
||||
EXPECT_NE(classification_result->classifications->categories, nullptr);
|
||||
EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label,
|
||||
"cheeseburger"),
|
||||
0);
|
||||
EXPECT_GE(classification_result->classifications->categories[0].score, 0.90);
|
||||
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierClassifyTest, FailsWithNullFrameBufferAndError) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(image_classifier, nullptr, &error);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
EXPECT_EQ(classification_result, nullptr);
|
||||
if (classification_result)
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("Expected non null frame buffer"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierClassifyTest, FailsWithNullImageDataAndError) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {.format = kRGB, .orientation = kTopLeft};
|
||||
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, &error);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
EXPECT_EQ(classification_result, nullptr);
|
||||
if (classification_result)
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("Invalid stride information"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierClassifyTest, SucceedsWithRoiWithinImageBounds) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {
|
||||
.format = kRGB,
|
||||
.orientation = kTopLeft,
|
||||
.dimension = {.width = image_data.width, .height = image_data.height},
|
||||
.buffer = image_data.pixel_data};
|
||||
|
||||
TfLiteBoundingBox bounding_box = {
|
||||
.origin_x = 0, .origin_y = 0, .width = 100, .height = 100};
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassifyWithRoi(image_classifier, &frame_buffer,
|
||||
&bounding_box, &error);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
ASSERT_NE(classification_result, nullptr);
|
||||
EXPECT_GE(classification_result->size, 1);
|
||||
EXPECT_NE(classification_result->classifications, nullptr);
|
||||
EXPECT_GE(classification_result->classifications->size, 1);
|
||||
EXPECT_NE(classification_result->classifications->categories, nullptr);
|
||||
EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label,
|
||||
"bagel"),
|
||||
0);
|
||||
EXPECT_GE(classification_result->classifications->categories[0].score, 0.30);
|
||||
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
}
|
||||
|
||||
TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBoundsAndError) {
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {
|
||||
.format = kRGB,
|
||||
.orientation = kTopLeft,
|
||||
.dimension = {.width = image_data.width, .height = image_data.height},
|
||||
.buffer = image_data.pixel_data};
|
||||
|
||||
TfLiteBoundingBox bounding_box = {
|
||||
.origin_x = 0, .origin_y = 0, .width = 250, .height = 250};
|
||||
TfLiteSupportError* error = nullptr;
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassifyWithRoi(image_classifier, &frame_buffer,
|
||||
&bounding_box, &error);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
|
||||
EXPECT_EQ(classification_result, nullptr);
|
||||
if (classification_result)
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
|
||||
ASSERT_NE(error, nullptr);
|
||||
EXPECT_EQ(error->code, kInvalidArgumentError);
|
||||
EXPECT_NE(error->message, nullptr);
|
||||
EXPECT_THAT(error->message, HasSubstr("Invalid crop coordinates"));
|
||||
|
||||
TfLiteSupportErrorDelete(error);
|
||||
}
|
||||
|
||||
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
|
||||
SucceedsWithClassNameDenyList) {
|
||||
char* denylisted_label_name = (char*)"cheeseburger";
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata);
|
||||
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
|
||||
char* label_denylist[12] = {denylisted_label_name};
|
||||
options.classification_options.label_denylist.list = label_denylist;
|
||||
options.classification_options.label_denylist.length = 1;
|
||||
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, nullptr);
|
||||
ASSERT_NE(image_classifier, nullptr);
|
||||
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {
|
||||
.format = kRGB,
|
||||
.orientation = kTopLeft,
|
||||
.dimension = {.width = image_data.width, .height = image_data.height},
|
||||
.buffer = image_data.pixel_data};
|
||||
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
ASSERT_NE(classification_result, nullptr);
|
||||
EXPECT_GE(classification_result->size, 1);
|
||||
EXPECT_NE(classification_result->classifications, nullptr);
|
||||
EXPECT_GE(classification_result->classifications->size, 1);
|
||||
EXPECT_NE(classification_result->classifications->categories, nullptr);
|
||||
EXPECT_NE(strcmp(classification_result->classifications->categories[0].label,
|
||||
denylisted_label_name),
|
||||
0);
|
||||
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
}
|
||||
|
||||
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
|
||||
SucceedsWithClassNameAllowList) {
|
||||
char* allowlisted_label_name = (char*)"cheeseburger";
|
||||
std::string model_path = JoinPath("./" /*test src dir*/, kTestDataDirectory,
|
||||
kMobileNetQuantizedWithMetadata)
|
||||
.data();
|
||||
|
||||
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
|
||||
options.base_options.model_file.file_path = model_path.data();
|
||||
|
||||
char* label_allowlist[12] = {allowlisted_label_name};
|
||||
options.classification_options.label_allowlist.list = label_allowlist;
|
||||
options.classification_options.label_allowlist.length = 1;
|
||||
|
||||
TfLiteImageClassifier* image_classifier =
|
||||
TfLiteImageClassifierFromOptions(&options, nullptr);
|
||||
ASSERT_NE(image_classifier, nullptr);
|
||||
|
||||
SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData image_data,
|
||||
LoadImage("burger-224.png"));
|
||||
|
||||
TfLiteFrameBuffer frame_buffer = {
|
||||
.format = kRGB,
|
||||
.orientation = kTopLeft,
|
||||
.dimension = {.width = image_data.width, .height = image_data.height},
|
||||
.buffer = image_data.pixel_data};
|
||||
|
||||
TfLiteClassificationResult* classification_result =
|
||||
TfLiteImageClassifierClassify(image_classifier, &frame_buffer, nullptr);
|
||||
|
||||
ImageDataFree(&image_data);
|
||||
if (image_classifier)
|
||||
TfLiteImageClassifierDelete(image_classifier);
|
||||
|
||||
ASSERT_NE(classification_result, nullptr);
|
||||
EXPECT_GE(classification_result->size, 1);
|
||||
EXPECT_NE(classification_result->classifications, nullptr);
|
||||
EXPECT_GE(classification_result->classifications->size, 1);
|
||||
EXPECT_NE(classification_result->classifications->categories, nullptr);
|
||||
EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label,
|
||||
allowlisted_label_name),
|
||||
0);
|
||||
|
||||
TfLiteClassificationResultDelete(classification_result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace vision
|
||||
} // namespace task
|
||||
} // namespace tflite
|
@ -1,5 +1,5 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow_lite_support:users"],
|
||||
default_visibility = ["//tensorflow_lite_support:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
@ -9,17 +9,12 @@ cc_library(
|
||||
"common.cc",
|
||||
],
|
||||
hdrs = ["common.h"],
|
||||
visibility = [
|
||||
"//tensorflow_lite_support:internal",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "tflite_use_c_api",
|
||||
values = {
|
||||
"copt": "-DTFLITE_USE_C_API",
|
||||
},
|
||||
visibility = ["//tensorflow_lite_support:__subpackages__"],
|
||||
)
|
||||
|
@ -15,7 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow_lite_support/cc/common.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/cord.h" // from @com_google_absl
|
||||
#include "absl/strings/str_cat.h" // from @com_google_absl
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
@ -25,6 +26,8 @@ absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code,
|
||||
TfLiteSupportStatus tfls_code) {
|
||||
// NOTE: Ignores `message` if the canonical code is ok.
|
||||
absl::Status status = absl::Status(canonical_code, message);
|
||||
// NOTE: Does nothing if the canonical code is ok.
|
||||
status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code)));
|
||||
return status;
|
||||
}
|
||||
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/string_view.h" // from @com_google_absl
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
@ -53,6 +53,12 @@ enum class TfLiteSupportStatus {
|
||||
kInvalidArgumentError = 2,
|
||||
// Invalid FlatBuffer file or buffer specified.
|
||||
kInvalidFlatBufferError = 3,
|
||||
// Model contains a builtin op that isn't supported by the OpResolver or
|
||||
// delegates.
|
||||
kUnsupportedBuiltinOp = 4,
|
||||
// Model contains a custom op that isn't supported by the OpResolver or
|
||||
// delegates.
|
||||
kUnsupportedCustomOp = 5,
|
||||
|
||||
// File I/O error codes.
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow_lite_support:users"],
|
||||
default_visibility = ["//tensorflow_lite_support:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
@ -8,8 +10,9 @@ cc_library(
|
||||
hdrs = [
|
||||
"statusor.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow_lite_support/cc/port/default:statusor",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
@ -18,31 +21,54 @@ cc_library(
|
||||
hdrs = [
|
||||
"status_macros.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow_lite_support/cc/port/default:status_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "configuration_proto_inc",
|
||||
hdrs = ["configuration_proto_inc.h"],
|
||||
deps = ["@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto"],
|
||||
)
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "tflite_wrapper",
|
||||
hdrs = ["tflite_wrapper.h"],
|
||||
deps = ["//tensorflow_lite_support/cc/port/default:tflite_wrapper"],
|
||||
)
|
||||
|
||||
# This is identical to the rule above, except that it gets built with
|
||||
# '-DTFLITE_USE_C_API'. This rule is used for unit tests that verify things
|
||||
# work correctly when built with TFLITE_USE_C_API defined.
|
||||
cc_library(
|
||||
name = "tflite_wrapper_with_c_api_for_test",
|
||||
name = "integral_types",
|
||||
hdrs = ["integral_types.h"],
|
||||
visibility = ["//tensorflow_lite_support:users"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gtest_main",
|
||||
testonly = 1,
|
||||
hdrs = ["tflite_wrapper.h"],
|
||||
hdrs = [
|
||||
"benchmark.h",
|
||||
"gmock.h",
|
||||
"gtest.h",
|
||||
"status_matchers.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow_lite_support:internal",
|
||||
],
|
||||
deps = [
|
||||
"//intelligence/mobile_acceleration/proto:allowlist_portable_proto",
|
||||
"//intelligence/mobile_acceleration/support_library:tflite_wrapper_with_c_api_for_test",
|
||||
"//tensorflow_lite_support/cc/port/default:status_matchers",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "integral_types",
|
||||
hdrs = ["integral_types.h"],
|
||||
name = "proto2",
|
||||
hdrs = [
|
||||
"proto2.h",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_protobuf//:protobuf",
|
||||
],
|
||||
)
|
||||
|
@ -1,30 +1,22 @@
|
||||
""".bzl file for TFLite Support open source build configs."""
|
||||
|
||||
load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library")
|
||||
|
||||
def provided_args(**kwargs):
|
||||
"""Returns the keyword arguments omitting None arguments."""
|
||||
return {k: v for k, v in kwargs.items() if v != None}
|
||||
|
||||
def support_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps = [], testonly = 0):
|
||||
"""Generate cc_proto_library for TFLite Support open source version.
|
||||
def support_cc_proto_library(name, deps = [], visibility = None):
|
||||
"""Generates cc_proto_library for TFLite Support open source version.
|
||||
|
||||
Args:
|
||||
name: the name of the cc_proto_library.
|
||||
srcs: the .proto files of the cc_proto_library for Bazel use.
|
||||
deps: a list of dependency labels for Bazel use; must be proto_library.
|
||||
visibility: visibility of this target.
|
||||
deps: a list of dependency labels for Bazel use; must be cc_proto_library.
|
||||
testonly: test only proto or not.
|
||||
"""
|
||||
_ignore = [deps]
|
||||
cc_proto_library(**provided_args(
|
||||
|
||||
# Verified in the external path.
|
||||
# buildifier: disable=native-cc-proto
|
||||
native.cc_proto_library(**provided_args(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
visibility = visibility,
|
||||
deps = cc_deps,
|
||||
testonly = testonly,
|
||||
cc_libs = ["@com_google_protobuf//:protobuf"],
|
||||
protoc = "@com_google_protobuf//:protoc",
|
||||
default_runtime = "@com_google_protobuf//:protobuf",
|
||||
alwayslink = 1,
|
||||
deps = deps,
|
||||
))
|
||||
|
21
third_party/tflite_support/src/tensorflow_lite_support/cc/port/configuration_proto_inc.h
vendored
Normal file
21
third_party/tflite_support/src/tensorflow_lite_support/cc/port/configuration_proto_inc.h
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_
|
||||
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_PORT_CONFIGURATION_PROTO_INC_H_
|
@ -1,29 +1,11 @@
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow_lite_support/cc/port:__pkg__",
|
||||
"//tensorflow_lite_support/cc/test:__pkg__",
|
||||
"//tensorflow_lite_support/cc/port:__subpackages__",
|
||||
"//tensorflow_lite_support/cc/test:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "statusor",
|
||||
srcs = ["statusor.cc"],
|
||||
hdrs = [
|
||||
"statusor.h",
|
||||
"statusor_internals.h",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/meta:type_traits",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"@com_google_absl//absl/utility",
|
||||
"@com_google_glog//:glog",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "status_macros",
|
||||
hdrs = [
|
||||
@ -35,6 +17,15 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "status_matchers",
|
||||
testonly = 1,
|
||||
hdrs = ["status_matchers.h"],
|
||||
deps = [
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tflite_wrapper",
|
||||
srcs = ["tflite_wrapper.cc"],
|
||||
@ -42,9 +33,32 @@ cc_library(
|
||||
"tflite_wrapper.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow_lite_support/cc/port:status_macros",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@flatbuffers",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite:minimal_logging",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:nnapi_plugin",
|
||||
"@org_tensorflow//tensorflow/lite/delegates:interpreter_utils",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto",
|
||||
],
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:delegate_registry",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:flatbuffer_to_proto",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:proto_to_flatbuffer",
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark",
|
||||
"//tensorflow_lite_support/cc/port:status_macros",
|
||||
] + select({
|
||||
# We only intend to use TFLite mini-benchmark on arm-based Andorid and x86_64 Linux.
|
||||
"@org_tensorflow//tensorflow:android_arm": [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation",
|
||||
],
|
||||
"@org_tensorflow//tensorflow:android_arm64": [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation",
|
||||
],
|
||||
"@org_tensorflow//tensorflow:linux_x86_64": [
|
||||
"@org_tensorflow//tensorflow/lite/experimental/acceleration/mini_benchmark:mini_benchmark_implementation",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
|
||||
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/base/optimization.h" // from @com_google_absl
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
|
||||
// Evaluates an expression that produces a `absl::Status`. If the status is not
|
||||
// ok, returns it from the current function.
|
||||
|
@ -1,67 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// This file is forked from absl.
|
||||
|
||||
#include "tensorflow_lite_support/cc/port/default/statusor.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "base/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
BadStatusOrAccess::BadStatusOrAccess(absl::Status status)
|
||||
: status_(std::move(status)) {}
|
||||
|
||||
BadStatusOrAccess::~BadStatusOrAccess() = default;
|
||||
|
||||
const char* BadStatusOrAccess::what() const noexcept {
|
||||
return "Bad StatusOr access";
|
||||
}
|
||||
|
||||
const absl::Status& BadStatusOrAccess::status() const {
|
||||
return status_;
|
||||
}
|
||||
|
||||
namespace internal_statusor {
|
||||
|
||||
void Helper::HandleInvalidStatusCtorArg(absl::Status* status) {
|
||||
const char* kMessage =
|
||||
"An OK status is not a valid constructor argument to StatusOr<T>";
|
||||
LOG(DFATAL) << kMessage;
|
||||
// In optimized builds, we will fall back to ::util::error::INTERNAL.
|
||||
*status = absl::InternalError(kMessage);
|
||||
}
|
||||
|
||||
void Helper::Crash(const absl::Status& status) {
|
||||
LOG(FATAL) << "Attempting to fetch value instead of handling error "
|
||||
<< status;
|
||||
_Exit(1);
|
||||
}
|
||||
|
||||
void ThrowBadStatusOrAccess(absl::Status status) {
|
||||
#ifdef ABSL_HAVE_EXCEPTIONS
|
||||
throw BadStatusOrAccess(std::move(status));
|
||||
#else
|
||||
LOG(FATAL) << "Attempting to fetch value instead of handling error "
|
||||
<< status;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace internal_statusor
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -1,584 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// This file is forked from absl.
|
||||
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
|
||||
|
||||
#include <exception>
|
||||
#include <initializer_list>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "absl/utility/utility.h"
|
||||
#include "tensorflow_lite_support/cc/port/default/statusor_internals.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
#ifndef SWIG
|
||||
// Exception type carrying the non-OK status of a StatusOr whose value was
// requested. `internal_statusor::ThrowBadStatusOrAccess` throws it when
// `ABSL_HAVE_EXCEPTIONS` is defined; otherwise that function logs fatally.
class BadStatusOrAccess : public std::exception {
|
||||
public:
|
||||
// Stores `status` (moved in) for later retrieval via `status()`.
explicit BadStatusOrAccess(absl::Status status);
|
||||
~BadStatusOrAccess() override;
|
||||
// Returns the fixed string "Bad StatusOr access"; use `status()` for the
// underlying error.
const char* what() const noexcept override;
|
||||
// Returns the status this exception was constructed with.
const absl::Status& status() const;
|
||||
|
||||
private:
|
||||
absl::Status status_;
|
||||
};
|
||||
#endif // !SWIG
|
||||
|
||||
// Returned StatusOr objects may not be ignored.
|
||||
// Note: Disabled for SWIG as it doesn't parse attributes correctly. Codesearch
|
||||
// doesn't handle ifdefs as part of a class definitions (b/6995610), so we use a
|
||||
// forward declaration.
|
||||
#ifndef SWIG
|
||||
template <typename T>
|
||||
class ABSL_MUST_USE_RESULT StatusOr;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
class StatusOr : private internal_statusor::StatusOrData<T>,
|
||||
private internal_statusor::CopyCtorBase<T>,
|
||||
private internal_statusor::MoveCtorBase<T>,
|
||||
private internal_statusor::CopyAssignBase<T>,
|
||||
private internal_statusor::MoveAssignBase<T> {
|
||||
template <typename U>
|
||||
friend class StatusOr;
|
||||
|
||||
typedef internal_statusor::StatusOrData<T> Base;
|
||||
|
||||
public:
|
||||
typedef T value_type;
|
||||
|
||||
// Constructs a new StatusOr with Status::UNKNOWN status. This is marked
|
||||
// 'explicit' to try to catch cases like 'return {};', where people think
|
||||
// tflite::support::StatusOr<std::vector<int>> will be initialized with an
|
||||
// empty vector, instead of a Status::UNKNOWN status.
|
||||
explicit StatusOr();
|
||||
|
||||
// StatusOr<T> is copy constructible if T is copy constructible.
|
||||
StatusOr(const StatusOr&) = default;
|
||||
// StatusOr<T> is copy assignable if T is copy constructible and copy
|
||||
// assignable.
|
||||
StatusOr& operator=(const StatusOr&) = default;
|
||||
|
||||
#ifndef SWIG
|
||||
|
||||
// StatusOr<T> is move constructible if T is move constructible.
|
||||
StatusOr(StatusOr&&) = default;
|
||||
// StatusOr<T> is moveAssignable if T is move constructible and move
|
||||
// assignable.
|
||||
StatusOr& operator=(StatusOr&&) = default;
|
||||
|
||||
// Converting constructors from StatusOr<U>, when T is constructible from U.
|
||||
// To avoid ambiguity, they are disabled if T is also constructible from
|
||||
// StatusOr<U>. Explicit iff the corresponding construction of T from U is
|
||||
// explicit.
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, const U&>,
|
||||
std::is_convertible<const U&, T>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value,
|
||||
int> = 0>
|
||||
StatusOr(const StatusOr<U>& other) // NOLINT
|
||||
: Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, const U&>,
|
||||
absl::negation<std::is_convertible<const U&, T>>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value,
|
||||
int> = 0>
|
||||
explicit StatusOr(const StatusOr<U>& other)
|
||||
: Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
|
||||
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, U&&>,
|
||||
std::is_convertible<U&&, T>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value,
|
||||
int> = 0>
|
||||
StatusOr(StatusOr<U>&& other) // NOLINT
|
||||
: Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, U&&>,
|
||||
absl::negation<std::is_convertible<U&&, T>>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleFromStatusOr<T, U>>>::value,
|
||||
int> = 0>
|
||||
explicit StatusOr(StatusOr<U>&& other)
|
||||
: Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
|
||||
|
||||
// Conversion copy/move assignment operator, T must be constructible and
|
||||
// assignable from U. Only enable if T cannot be directly assigned from
|
||||
// StatusOr<U>.
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, const U&>,
|
||||
std::is_assignable<T, const U&>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleOrAssignableFromStatusOr<
|
||||
T,
|
||||
U>>>::value,
|
||||
int> = 0>
|
||||
StatusOr& operator=(const StatusOr<U>& other) {
|
||||
this->Assign(other);
|
||||
return *this;
|
||||
}
|
||||
template <
|
||||
typename U,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
absl::negation<std::is_same<T, U>>,
|
||||
std::is_constructible<T, U&&>,
|
||||
std::is_assignable<T, U&&>,
|
||||
absl::negation<
|
||||
internal_statusor::
|
||||
IsConstructibleOrConvertibleOrAssignableFromStatusOr<
|
||||
T,
|
||||
U>>>::value,
|
||||
int> = 0>
|
||||
StatusOr& operator=(StatusOr<U>&& other) {
|
||||
this->Assign(std::move(other));
|
||||
return *this;
|
||||
}
|
||||
|
||||
#endif // SWIG
|
||||
|
||||
// Constructs a new StatusOr with the given value. After calling this
|
||||
// constructor, this->ok() will be true and the contained value may be
|
||||
// retrieved with value(), operator*(), or operator->().
|
||||
//
|
||||
// NOTE: Not explicit - we want to use StatusOr<T> as a return type
|
||||
// so it is convenient and sensible to be able to do 'return T()'
|
||||
// when the return type is StatusOr<T>.
|
||||
//
|
||||
// REQUIRES: T is copy constructible.
|
||||
// TODO(b/113125838): Replace this constructor with a direct-initialization
|
||||
// constructor.
|
||||
StatusOr(const T& value);
|
||||
|
||||
// Constructs a new StatusOr with the given non-ok status. After calling this
|
||||
// constructor, this->ok() will be false and calls to value() will CHECK-fail.
|
||||
//
|
||||
// NOTE: Not explicit - we want to use StatusOr<T> as a return
|
||||
// value, so it is convenient and sensible to be able to do 'return
|
||||
// Status()' when the return type is StatusOr<T>.
|
||||
//
|
||||
// REQUIRES: !status.ok(). This requirement is DCHECKed.
|
||||
// In optimized builds, passing util::OkStatus() here will have the effect
|
||||
// of passing util::error::INTERNAL as a fallback.
|
||||
StatusOr(const absl::Status& status);
|
||||
StatusOr& operator=(const absl::Status& status);
|
||||
|
||||
#ifndef SWIG
|
||||
// Perfect-forwarding value assignment operator.
|
||||
// If `*this` contains a `T` value before the call, the contained value is
|
||||
// assigned from `std::forward<U>(v)`; Otherwise, it is directly-initialized
|
||||
// from `std::forward<U>(v)`.
|
||||
// This function does not participate in overload unless:
|
||||
// 1. `std::is_constructible_v<T, U>` is true,
|
||||
// 2. `std::is_assignable_v<T&, U>` is true.
|
||||
// 3. `std::is_same_v<StatusOr<T>, std::remove_cvref_t<U>>` is false.
|
||||
// 4. Assigning `U` to `T` is not ambiguous:
|
||||
// If `U` is `StatusOr<V>` and `T` is constructible and assignable from
|
||||
// both `StatusOr<V>` and `V`, the assignment is considered bug-prone and
|
||||
// ambiguous thus will fail to compile. For example:
|
||||
// StatusOr<bool> s1 = true; // s1.ok() && *s1 == true
|
||||
// StatusOr<bool> s2 = false; // s2.ok() && *s2 == false
|
||||
// s1 = s2; // ambiguous, `s1 = *s2` or `s1 = bool(s2)`?
|
||||
template <
|
||||
typename U = T,
|
||||
typename = typename std::enable_if<absl::conjunction<
|
||||
std::is_constructible<T, U&&>,
|
||||
std::is_assignable<T&, U&&>,
|
||||
internal_statusor::IsForwardingAssignmentValid<T, U&&>>::value>::type>
|
||||
StatusOr& operator=(U&& v) {
|
||||
this->Assign(std::forward<U>(v));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Similar to the `const T&` overload.
|
||||
//
|
||||
// REQUIRES: T is move constructible.
|
||||
StatusOr(T&& value);
|
||||
|
||||
// RValue versions of the operations declared above.
|
||||
StatusOr(absl::Status&& status);
|
||||
StatusOr& operator=(absl::Status&& status);
|
||||
|
||||
// Constructs the inner value T in-place using the provided args, using the
|
||||
// T(args...) constructor.
|
||||
template <typename... Args>
|
||||
explicit StatusOr(absl::in_place_t, Args&&... args);
|
||||
template <typename U, typename... Args>
|
||||
explicit StatusOr(absl::in_place_t,
|
||||
std::initializer_list<U> ilist,
|
||||
Args&&... args);
|
||||
|
||||
// Constructs the inner value T in-place using the provided args, using the
|
||||
// T(U) (direct-initialization) constructor. Only valid if T can be
|
||||
// constructed from a U. Can accept move or copy constructors. Explicit if
|
||||
// U is not convertible to T. To avoid ambiguity, this is disabled if U is
|
||||
// a StatusOr<J>, where J is convertible to T.
|
||||
// Style waiver for implicit conversion granted in cl/209187539.
|
||||
template <typename U = T,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
internal_statusor::IsDirectInitializationValid<T, U&&>,
|
||||
std::is_constructible<T, U&&>,
|
||||
std::is_convertible<U&&, T>>::value,
|
||||
int> = 0>
|
||||
StatusOr(U&& u) // NOLINT
|
||||
: StatusOr(absl::in_place, std::forward<U>(u)) {}
|
||||
|
||||
template <typename U = T,
|
||||
absl::enable_if_t<
|
||||
absl::conjunction<
|
||||
internal_statusor::IsDirectInitializationValid<T, U&&>,
|
||||
std::is_constructible<T, U&&>,
|
||||
absl::negation<std::is_convertible<U&&, T>>>::value,
|
||||
int> = 0>
|
||||
explicit StatusOr(U&& u) // NOLINT
|
||||
: StatusOr(absl::in_place, std::forward<U>(u)) {}
|
||||
|
||||
#endif // SWIG
|
||||
|
||||
// Returns this->status().ok()
|
||||
ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); }
|
||||
|
||||
// Returns a reference to our status. If this contains a T, then
|
||||
// returns util::OkStatus().
|
||||
#ifdef SWIG
|
||||
const ::util::Status& status() const;
|
||||
#else // SWIG
|
||||
const absl::Status& status() const&;
|
||||
absl::Status status() &&;
|
||||
#endif // SWIG
|
||||
|
||||
// Returns a reference to the held value if `this->ok()`. Otherwise, throws
|
||||
// `absl::BadStatusOrAccess` if exception is enabled, or `LOG(FATAL)` if
|
||||
// exception is disabled.
|
||||
// If you have already checked the status using `this->ok()` or
|
||||
// `operator bool()`, you probably want to use `operator*()` or `operator->()`
|
||||
// to access the value instead of `value`.
|
||||
// Note: for value types that are cheap to copy, prefer simple code:
|
||||
//
|
||||
// T value = statusor.value();
|
||||
//
|
||||
// Otherwise, if the value type is expensive to copy, but can be left
|
||||
// in the StatusOr, simply assign to a reference:
|
||||
//
|
||||
// T& value = statusor.value(); // or `const T&`
|
||||
//
|
||||
// Otherwise, if the value type supports an efficient move, it can be
|
||||
// used as follows:
|
||||
//
|
||||
// T value = std::move(statusor).value();
|
||||
//
|
||||
// The `std::move` on statusor instead of on the whole expression enables
|
||||
// warnings about possible uses of the statusor object after the move.
|
||||
#ifdef SWIG
|
||||
const T& value() const;
|
||||
#else // SWIG
|
||||
const T& value() const&;
|
||||
T& value() &;
|
||||
const T&& value() const&&;
|
||||
T&& value() &&;
|
||||
#endif // SWIG
|
||||
|
||||
#ifndef SWIG
|
||||
// Returns a reference to the current value.
|
||||
//
|
||||
// REQUIRES: this->ok() == true, otherwise the behavior is undefined.
|
||||
//
|
||||
// Use this->ok() or `operator bool()` to verify that there is a current
|
||||
// value. Alternatively, see value() for a similar API that guarantees
|
||||
// CHECK-failing if there is no current value.
|
||||
const T& operator*() const&;
|
||||
T& operator*() &;
|
||||
const T&& operator*() const&&;
|
||||
T&& operator*() &&;
|
||||
#endif // SWIG
|
||||
|
||||
#ifndef SWIG
|
||||
// Returns a pointer to the current value.
|
||||
//
|
||||
// REQUIRES: this->ok() == true, otherwise the behavior is undefined.
|
||||
//
|
||||
// Use this->ok() or `operator bool()` to verify that there is a current
|
||||
// value.
|
||||
const T* operator->() const;
|
||||
T* operator->();
|
||||
#endif // SWIG
|
||||
|
||||
#ifndef SWIG
|
||||
// Returns a copy of the current value if this->ok() == true. Otherwise
|
||||
// returns a default value.
|
||||
template <typename U>
|
||||
T value_or(U&& default_value) const&;
|
||||
template <typename U>
|
||||
T value_or(U&& default_value) &&;
|
||||
#endif // SWIG
|
||||
|
||||
// Ignores any errors. This method does nothing except potentially suppress
|
||||
// complaints from any tools that are checking that errors are not dropped on
|
||||
// the floor.
|
||||
void IgnoreError() const;
|
||||
|
||||
#ifndef SWIG
|
||||
// Reconstructs the inner value T in-place using the provided args, using the
|
||||
// T(args...) constructor. Returns reference to the reconstructed `T`.
|
||||
template <typename... Args>
|
||||
T& emplace(Args&&... args) {
|
||||
if (ok()) {
|
||||
this->Clear();
|
||||
this->MakeValue(std::forward<Args>(args)...);
|
||||
} else {
|
||||
this->MakeValue(std::forward<Args>(args)...);
|
||||
this->status_ = absl::OkStatus();
|
||||
}
|
||||
return this->data_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename U,
|
||||
typename... Args,
|
||||
absl::enable_if_t<
|
||||
std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value,
|
||||
int> = 0>
|
||||
T& emplace(std::initializer_list<U> ilist, Args&&... args) {
|
||||
if (ok()) {
|
||||
this->Clear();
|
||||
this->MakeValue(ilist, std::forward<Args>(args)...);
|
||||
} else {
|
||||
this->MakeValue(ilist, std::forward<Args>(args)...);
|
||||
this->status_ = absl::OkStatus();
|
||||
}
|
||||
return this->data_;
|
||||
}
|
||||
#endif // SWIG
|
||||
|
||||
private:
|
||||
#ifndef SWIG
|
||||
using internal_statusor::StatusOrData<T>::Assign;
|
||||
template <typename U>
|
||||
void Assign(const StatusOr<U>& other);
|
||||
template <typename U>
|
||||
void Assign(StatusOr<U>&& other);
|
||||
#endif // SWIG
|
||||
};
|
||||
|
||||
#ifndef SWIG
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Implementation details for StatusOr<T>
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>::StatusOr()
|
||||
: Base(absl::Status(absl::StatusCode::kUnknown, "")) {}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>::StatusOr(const T& value) : Base(value) {}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>::StatusOr(const absl::Status& status)
|
||||
: Base(status) {}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>& StatusOr<T>::operator=(
|
||||
const absl::Status& status) {
|
||||
this->Assign(status);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>::StatusOr(absl::Status&& status)
|
||||
: Base(std::move(status)) {}
|
||||
|
||||
template <typename T>
|
||||
tflite::support::StatusOr<T>& StatusOr<T>::operator=(absl::Status&& status) {
|
||||
this->Assign(std::move(status));
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
inline void StatusOr<T>::Assign(const StatusOr<U>& other) {
|
||||
if (other.ok()) {
|
||||
this->Assign(other.value());
|
||||
} else {
|
||||
this->Assign(other.status());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
inline void StatusOr<T>::Assign(StatusOr<U>&& other) {
|
||||
if (other.ok()) {
|
||||
this->Assign(std::move(other).value());
|
||||
} else {
|
||||
this->Assign(std::move(other).status());
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
template <typename... Args>
|
||||
tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args)
|
||||
: Base(absl::in_place, std::forward<Args>(args)...) {}
|
||||
|
||||
template <typename T>
|
||||
template <typename U, typename... Args>
|
||||
tflite::support::StatusOr<T>::StatusOr(absl::in_place_t,
|
||||
std::initializer_list<U> ilist,
|
||||
Args&&... args)
|
||||
: Base(absl::in_place, ilist, std::forward<Args>(args)...) {}
|
||||
|
||||
template <typename T>
|
||||
const absl::Status& StatusOr<T>::status() const& {
|
||||
return this->status_;
|
||||
}
|
||||
template <typename T>
|
||||
absl::Status StatusOr<T>::status() && {
|
||||
return ok() ? absl::OkStatus() : std::move(this->status_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T& StatusOr<T>::value() const& {
|
||||
if (!this->ok())
|
||||
internal_statusor::ThrowBadStatusOrAccess(this->status_);
|
||||
return this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& StatusOr<T>::value() & {
|
||||
if (!this->ok())
|
||||
internal_statusor::ThrowBadStatusOrAccess(this->status_);
|
||||
return this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T&& StatusOr<T>::value() const&& {
|
||||
if (!this->ok()) {
|
||||
internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_));
|
||||
}
|
||||
return std::move(this->data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T&& StatusOr<T>::value() && {
|
||||
if (!this->ok()) {
|
||||
internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_));
|
||||
}
|
||||
return std::move(this->data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T& StatusOr<T>::operator*() const& {
|
||||
this->EnsureOk();
|
||||
return this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& StatusOr<T>::operator*() & {
|
||||
this->EnsureOk();
|
||||
return this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T&& StatusOr<T>::operator*() const&& {
|
||||
this->EnsureOk();
|
||||
return std::move(this->data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T&& StatusOr<T>::operator*() && {
|
||||
this->EnsureOk();
|
||||
return std::move(this->data_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* StatusOr<T>::operator->() const {
|
||||
this->EnsureOk();
|
||||
return &this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* StatusOr<T>::operator->() {
|
||||
this->EnsureOk();
|
||||
return &this->data_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
T StatusOr<T>::value_or(U&& default_value) const& {
|
||||
if (ok()) {
|
||||
return this->data_;
|
||||
}
|
||||
return std::forward<U>(default_value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename U>
|
||||
T StatusOr<T>::value_or(U&& default_value) && {
|
||||
if (ok()) {
|
||||
return std::move(this->data_);
|
||||
}
|
||||
return std::forward<U>(default_value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void StatusOr<T>::IgnoreError() const {
|
||||
// no-op
|
||||
}
|
||||
|
||||
#endif // SWIG
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
|
6
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
vendored
6
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/statusor_internals.h
vendored
@ -20,9 +20,9 @@ limitations under the License.
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/utility/utility.h"
|
||||
#include "absl/meta/type_traits.h" // from @com_google_absl
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/utility/utility.h" // from @com_google_absl
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
323
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
vendored
323
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
vendored
@ -15,45 +15,332 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "absl/strings/str_format.h" // from @com_google_absl
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/delegates/interpreter_utils.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/flatbuffer_to_proto.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/proto_to_flatbuffer.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow_lite_support/cc/port/status_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
namespace {
|
||||
using tflite::delegates::DelegatePluginRegistry;
|
||||
using tflite::delegates::InterpreterUtils;
|
||||
using tflite::proto::ComputeSettings;
|
||||
using tflite::proto::Delegate;
|
||||
} // namespace
|
||||
|
||||
/* static */
|
||||
absl::Status TfLiteInterpreterWrapper::SanityCheckComputeSettings(
|
||||
const ComputeSettings& compute_settings) {
|
||||
Delegate delegate = compute_settings.tflite_settings().delegate();
|
||||
if (delegate != Delegate::NONE && delegate != Delegate::GPU &&
|
||||
delegate != Delegate::HEXAGON && delegate != Delegate::NNAPI &&
|
||||
delegate != Delegate::XNNPACK && delegate != Delegate::EDGETPU_CORAL) {
|
||||
return absl::UnimplementedError(absl::StrFormat(
|
||||
"Using delegate '%s' is not supported.", Delegate_Name(delegate)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
TfLiteInterpreterWrapper::TfLiteInterpreterWrapper(
|
||||
const std::string& default_model_namespace,
|
||||
const std::string& default_model_id)
|
||||
: delegate_(nullptr, nullptr),
|
||||
got_error_do_not_delegate_anymore_(false),
|
||||
default_model_namespace_(default_model_namespace),
|
||||
default_model_id_(default_model_id),
|
||||
mini_benchmark_(nullptr) {}
|
||||
|
||||
std::string TfLiteInterpreterWrapper::ModelNamespace() {
|
||||
const auto& ns_from_acceleration =
|
||||
compute_settings_.model_namespace_for_statistics();
|
||||
return ns_from_acceleration.empty() ? default_model_namespace_
|
||||
: ns_from_acceleration;
|
||||
}
|
||||
|
||||
std::string TfLiteInterpreterWrapper::ModelID() {
|
||||
const auto& id_from_acceleration =
|
||||
compute_settings_.model_identifier_for_statistics();
|
||||
return id_from_acceleration.empty() ? default_model_id_
|
||||
: id_from_acceleration;
|
||||
}
|
||||
|
||||
// This is the deprecated overload that doesn't take an
|
||||
// InterpreterCreationResources parameter.
|
||||
absl::Status TfLiteInterpreterWrapper::InitializeWithFallback(
|
||||
std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)>
|
||||
interpreter_initializer,
|
||||
const tflite::proto::ComputeSettings& compute_settings) {
|
||||
if (compute_settings.has_preference() ||
|
||||
compute_settings.has_tflite_settings()) {
|
||||
return absl::UnimplementedError(
|
||||
"Acceleration via ComputeSettings is not supported yet.");
|
||||
const ComputeSettings& compute_settings) {
|
||||
return InitializeWithFallback(
|
||||
[interpreter_initializer](
|
||||
const InterpreterCreationResources& resources,
|
||||
std::unique_ptr<tflite::Interpreter>* interpreter_out)
|
||||
-> absl::Status {
|
||||
RETURN_IF_ERROR(interpreter_initializer(interpreter_out));
|
||||
if (*interpreter_out != nullptr &&
|
||||
resources.optional_delegate != nullptr) {
|
||||
TfLiteStatus status =
|
||||
(*interpreter_out)
|
||||
->ModifyGraphWithDelegate(resources.optional_delegate);
|
||||
if (status != kTfLiteOk) {
|
||||
*interpreter_out = nullptr;
|
||||
RETURN_IF_ERROR(
|
||||
absl::InvalidArgumentError("Applying delegate failed"));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
},
|
||||
compute_settings);
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::InitializeWithFallback(
|
||||
std::function<absl::Status(const InterpreterCreationResources&,
|
||||
std::unique_ptr<tflite::Interpreter>*)>
|
||||
interpreter_initializer,
|
||||
const ComputeSettings& compute_settings) {
|
||||
// Store interpreter initializer if not already here.
|
||||
if (interpreter_initializer_) {
|
||||
return absl::FailedPreconditionError(
|
||||
"InitializeWithFallback already called.");
|
||||
}
|
||||
RETURN_IF_ERROR(interpreter_initializer(&interpreter_));
|
||||
return interpreter_->AllocateTensors() != kTfLiteOk
|
||||
? absl::InternalError(
|
||||
"TFLite interpreter: AllocateTensors() failed.")
|
||||
: absl::OkStatus();
|
||||
interpreter_initializer_ = std::move(interpreter_initializer);
|
||||
|
||||
// Sanity check and copy ComputeSettings.
|
||||
RETURN_IF_ERROR(SanityCheckComputeSettings(compute_settings));
|
||||
compute_settings_ = compute_settings;
|
||||
if (compute_settings_.has_settings_to_test_locally()) {
|
||||
flatbuffers::FlatBufferBuilder mini_benchmark_settings_fbb;
|
||||
const auto* mini_benchmark_settings =
|
||||
tflite::ConvertFromProto(compute_settings_.settings_to_test_locally(),
|
||||
&mini_benchmark_settings_fbb);
|
||||
mini_benchmark_ = tflite::acceleration::CreateMiniBenchmark(
|
||||
*mini_benchmark_settings, ModelNamespace(), ModelID());
|
||||
const tflite::ComputeSettingsT from_minibenchmark =
|
||||
mini_benchmark_->GetBestAcceleration();
|
||||
if (from_minibenchmark.tflite_settings != nullptr) {
|
||||
TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, "Using mini benchmark results\n");
|
||||
compute_settings_ = tflite::ConvertFromFlatbuffer(
|
||||
from_minibenchmark, /*skip_mini_benchmark_settings=*/true);
|
||||
}
|
||||
// Trigger mini benchmark if it hasn't already run. Vast majority of cases
|
||||
// should not actually do anything, since first runs are rare.
|
||||
mini_benchmark_->TriggerMiniBenchmark();
|
||||
mini_benchmark_->MarkAndGetEventsToLog();
|
||||
}
|
||||
|
||||
// Initialize fallback behavior.
|
||||
fallback_on_compilation_error_ =
|
||||
compute_settings_.tflite_settings()
|
||||
.fallback_settings()
|
||||
.allow_automatic_fallback_on_compilation_error() ||
|
||||
// Deprecated, keep supporting for backward compatibility.
|
||||
compute_settings_.tflite_settings()
|
||||
.nnapi_settings()
|
||||
.fallback_settings()
|
||||
.allow_automatic_fallback_on_compilation_error();
|
||||
fallback_on_execution_error_ =
|
||||
compute_settings_.tflite_settings()
|
||||
.fallback_settings()
|
||||
.allow_automatic_fallback_on_execution_error() ||
|
||||
// Deprecated, keep supporting for backward compatibility.
|
||||
compute_settings_.tflite_settings()
|
||||
.nnapi_settings()
|
||||
.fallback_settings()
|
||||
.allow_automatic_fallback_on_execution_error();
|
||||
|
||||
return InitializeWithFallbackAndResize();
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::AllocateTensors() {
|
||||
if (interpreter_->AllocateTensors() != kTfLiteOk) {
|
||||
return absl::InternalError("AllocateTensors() failed.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// TODO(b/173406463): the `resize` parameter is going to be used by
|
||||
// ResizeAndAllocateTensors functions, coming soon.
|
||||
absl::Status TfLiteInterpreterWrapper::InitializeWithFallbackAndResize(
|
||||
std::function<absl::Status(Interpreter*)> resize) {
|
||||
InterpreterCreationResources resources{};
|
||||
if (got_error_do_not_delegate_anymore_ ||
|
||||
compute_settings_.tflite_settings().delegate() == Delegate::NONE) {
|
||||
delegate_.reset(nullptr);
|
||||
} else {
|
||||
// Initialize delegate and add it to 'resources'.
|
||||
RETURN_IF_ERROR(InitializeDelegate());
|
||||
resources.optional_delegate = delegate_.get();
|
||||
}
|
||||
|
||||
absl::Status status = interpreter_initializer_(resources, &interpreter_);
|
||||
if (resources.optional_delegate == nullptr) {
|
||||
RETURN_IF_ERROR(status);
|
||||
}
|
||||
if (resources.optional_delegate != nullptr && !status.ok()) {
|
||||
// Any error when constructing the interpreter is assumed to be a delegate
|
||||
// compilation error. If a delegate compilation error occurs, stop
|
||||
// delegation from happening in the future.
|
||||
got_error_do_not_delegate_anymore_ = true;
|
||||
delegate_.reset(nullptr);
|
||||
if (fallback_on_compilation_error_) {
|
||||
InterpreterCreationResources fallback_resources{};
|
||||
fallback_resources.optional_delegate = nullptr;
|
||||
RETURN_IF_ERROR(
|
||||
interpreter_initializer_(fallback_resources, &interpreter_));
|
||||
} else {
|
||||
// If instructed not to fallback, return error.
|
||||
return absl::InternalError(absl::StrFormat(
|
||||
"ModifyGraphWithDelegate() failed for delegate '%s'.",
|
||||
Delegate_Name(compute_settings_.tflite_settings().delegate())));
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_ERROR(resize(interpreter_.get()));
|
||||
if (compute_settings_.tflite_settings().cpu_settings().num_threads() != -1) {
|
||||
if (interpreter_->SetNumThreads(
|
||||
compute_settings_.tflite_settings().cpu_settings().num_threads()) !=
|
||||
kTfLiteOk) {
|
||||
return absl::InternalError("Failed setting number of CPU threads");
|
||||
}
|
||||
}
|
||||
SetTfLiteCancellation();
|
||||
|
||||
if (!delegate_) {
|
||||
// Just allocate tensors and return.
|
||||
return AllocateTensors();
|
||||
}
|
||||
|
||||
// The call to ModifyGraphWithDelegate() leaves the interpreter in a usable
|
||||
// state in case of failure: calling AllocateTensors() will silently fallback
|
||||
// on CPU in such a situation.
|
||||
return AllocateTensors();
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::InitializeDelegate() {
|
||||
if (delegate_ == nullptr) {
|
||||
Delegate which_delegate = compute_settings_.tflite_settings().delegate();
|
||||
const tflite::ComputeSettings* compute_settings =
|
||||
tflite::ConvertFromProto(compute_settings_, &flatbuffers_builder_);
|
||||
|
||||
if (which_delegate == Delegate::NNAPI) {
|
||||
RETURN_IF_ERROR(
|
||||
LoadDelegatePlugin("Nnapi", *compute_settings->tflite_settings()));
|
||||
} else if (which_delegate == Delegate::HEXAGON) {
|
||||
RETURN_IF_ERROR(
|
||||
LoadDelegatePlugin("Hexagon", *compute_settings->tflite_settings()));
|
||||
} else if (which_delegate == Delegate::GPU) {
|
||||
RETURN_IF_ERROR(
|
||||
LoadDelegatePlugin("Gpu", *compute_settings->tflite_settings()));
|
||||
} else if (which_delegate == Delegate::EDGETPU) {
|
||||
RETURN_IF_ERROR(
|
||||
LoadDelegatePlugin("EdgeTpu", *compute_settings->tflite_settings()));
|
||||
} else if (which_delegate == Delegate::EDGETPU_CORAL) {
|
||||
RETURN_IF_ERROR(LoadDelegatePlugin("EdgeTpuCoral",
|
||||
*compute_settings->tflite_settings()));
|
||||
} else if (which_delegate == Delegate::XNNPACK) {
|
||||
RETURN_IF_ERROR(
|
||||
LoadDelegatePlugin("XNNPack", *compute_settings->tflite_settings()));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
|
||||
const std::function<absl::Status(tflite::Interpreter* interpreter)>&
|
||||
set_inputs) {
|
||||
RETURN_IF_ERROR(set_inputs(interpreter_.get()));
|
||||
return interpreter_->Invoke() != kTfLiteOk
|
||||
? absl::InternalError("TFLite interpreter: Invoke() failed.")
|
||||
: absl::OkStatus();
|
||||
// Reset cancel flag before calling `Invoke()`.
|
||||
cancel_flag_.Set(false);
|
||||
TfLiteStatus status = kTfLiteError;
|
||||
if (fallback_on_execution_error_) {
|
||||
status = InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
|
||||
} else {
|
||||
status = interpreter_->Invoke();
|
||||
}
|
||||
if (status == kTfLiteOk) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// Assume InvokeWithoutFallback() is guarded under caller's synchronization.
|
||||
// Assume the inference is cancelled successfully if Invoke() returns
|
||||
// kTfLiteError and the cancel flag is `true`.
|
||||
if (status == kTfLiteError && cancel_flag_.Get()) {
|
||||
return absl::CancelledError("Invoke() cancelled.");
|
||||
}
|
||||
if (delegate_) {
|
||||
// Mark that an error occurred so that later invocations immediately
|
||||
// fallback to CPU.
|
||||
got_error_do_not_delegate_anymore_ = true;
|
||||
// InvokeWithCPUFallback returns `kTfLiteDelegateError` in case of
|
||||
// *successful* fallback: convert it to an OK status.
|
||||
if (status == kTfLiteDelegateError) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
return absl::InternalError("Invoke() failed.");
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
|
||||
return interpreter_->Invoke() != kTfLiteOk
|
||||
? absl::InternalError("TFLite interpreter: Invoke() failed.")
|
||||
: absl::OkStatus();
|
||||
// Reset cancel flag before calling `Invoke()`.
|
||||
cancel_flag_.Set(false);
|
||||
TfLiteStatus status = interpreter_->Invoke();
|
||||
if (status != kTfLiteOk) {
|
||||
// Assume InvokeWithoutFallback() is guarded under caller's synchronization.
|
||||
// Assume the inference is cancelled successfully if Invoke() returns
|
||||
// kTfLiteError and the cancel flag is `true`.
|
||||
if (status == kTfLiteError && cancel_flag_.Get()) {
|
||||
return absl::CancelledError("Invoke() cancelled.");
|
||||
}
|
||||
return absl::InternalError("Invoke() failed.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void TfLiteInterpreterWrapper::Cancel() {
|
||||
// NOP
|
||||
cancel_flag_.Set(true);
|
||||
}
|
||||
|
||||
void TfLiteInterpreterWrapper::SetTfLiteCancellation() {
|
||||
// Create a cancellation check function and set to the TFLite interpreter.
|
||||
auto check_cancel_flag = [](void* data) {
|
||||
auto* cancel_flag = reinterpret_cast<CancelFlag*>(data);
|
||||
return cancel_flag->Get();
|
||||
};
|
||||
interpreter_->SetCancellationFunction(reinterpret_cast<void*>(&cancel_flag_),
|
||||
check_cancel_flag);
|
||||
}
|
||||
|
||||
absl::Status TfLiteInterpreterWrapper::LoadDelegatePlugin(
|
||||
const std::string& name,
|
||||
const tflite::TFLiteSettings& tflite_settings) {
|
||||
delegate_plugin_ = DelegatePluginRegistry::CreateByName(
|
||||
absl::StrFormat("%sPlugin", name), tflite_settings);
|
||||
|
||||
if (delegate_plugin_ == nullptr) {
|
||||
return absl::InternalError(absl::StrFormat(
|
||||
"Could not create %s plugin. Have you linked in the %s_plugin target?",
|
||||
name, name));
|
||||
}
|
||||
|
||||
delegate_ = delegate_plugin_->Create();
|
||||
if (delegate_ == nullptr) {
|
||||
return absl::InternalError(
|
||||
absl::StrFormat("Plugin did not create %s delegate.", name));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool TfLiteInterpreterWrapper::HasMiniBenchmarkCompleted() {
|
||||
if (mini_benchmark_ != nullptr &&
|
||||
mini_benchmark_->NumRemainingAccelerationTests() == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace support
|
||||
|
233
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
vendored
233
third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
vendored
@ -16,40 +16,123 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/status.h" // from @com_google_absl
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/configuration/delegate_registry.h"
|
||||
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/mini_benchmark.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
// Wrapper for a TfLiteInterpreter that may be accelerated[1]. This is NOT yet
|
||||
// implemented: this class only provides a first, minimal interface in the
|
||||
// meanwhile.
|
||||
// Options that are created by `TFLiteInterpreterWrapper` and will help to
|
||||
// initialize Interpreter in the callback function. `TFLiteInterpreterWrapper`
|
||||
// retains ownership of the included options, and will ensure that they remain
|
||||
// valid for the duration of the created interpreter's lifetime.
|
||||
struct InterpreterCreationResources {
|
||||
// The delegate created, based on the parameters in `ComputeSettings`.
|
||||
// `TfLiteInterpreterWrapper` exclusively owns the `TfLiteDelegate` object,
|
||||
// and maintains it through out the lifetime of `TfLiteInterpreterWrapper`.
|
||||
TfLiteDelegate* optional_delegate;
|
||||
|
||||
// Number of threads to use, or -1 to use the default number of threads.
|
||||
int num_threads = -1;
|
||||
|
||||
// Apply the InterpreterCreationResources to the InterpreterBuilder.
|
||||
// Note: caller is responsible for ensuring that arguments are valid,
|
||||
// e.g. that num_threads >= -1.
|
||||
void ApplyTo(tflite::InterpreterBuilder* interpreter_builder) const {
|
||||
if (optional_delegate != nullptr) {
|
||||
interpreter_builder->AddDelegate(optional_delegate);
|
||||
}
|
||||
if (num_threads != -1) {
|
||||
// We ignore the TfLiteStatus return value here; caller is responsible
|
||||
// for checking that num_threads is valid.
|
||||
(void)interpreter_builder->SetNumThreads(num_threads);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper for a TfLiteInterpreter that may be accelerated [1]. Meant to be
|
||||
// substituted for `unique_ptr<tflite::Interpreter>` class members.
|
||||
//
|
||||
// [1] See tensorflow/lite/experimental/acceleration for more details.
|
||||
// This class is in charge of:
|
||||
// * Picking, instantiating and configuring the right delegate for the provided
|
||||
// ComputeSettings [2],
|
||||
// * Providing methods to initialize and invoke the Interpreter with optional
|
||||
// (controlled through the ComputeSettings) automatic fallback to CPU if any
|
||||
// acceleration-related error occurs at compilation or runtime.
|
||||
// * TODO(b/169474250) Cache interpreters for multiple input sizes to enable
|
||||
// performant acceleration for the case where input size changes frequently.
|
||||
//
|
||||
// IMPORTANT: The only supported delegates are (as defined in [1]) NONE, GPU,
|
||||
// HEXAGON and NNAPI. Trying to use this class with EDGETPU or XNNPACK delegates
|
||||
// will cause an UnimplementedError to be thrown at initialization time.
|
||||
//
|
||||
// Like TfLiteInterpreter, this class is thread-compatible. Use from multiple
|
||||
// threads must be guarded by synchronization outside this class.
|
||||
//
|
||||
// [1]:
|
||||
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/acceleration/configuration/configuration.proto
|
||||
class TfLiteInterpreterWrapper {
|
||||
public:
|
||||
TfLiteInterpreterWrapper() = default;
|
||||
// Creates an instance to be associated with a TfLite model that could be
|
||||
// identified by (`default_model_namespace`, `default_model_id`). Note the
|
||||
// model identifier is generally used for the sake of logging.
|
||||
TfLiteInterpreterWrapper(const std::string& default_model_namespace,
|
||||
const std::string& default_model_id);
|
||||
TfLiteInterpreterWrapper()
|
||||
: TfLiteInterpreterWrapper("org.tensorflow.lite.support",
|
||||
"unknown_model_id") {}
|
||||
|
||||
virtual ~TfLiteInterpreterWrapper() = default;
|
||||
|
||||
// Calls `interpreter_initializer` and then `AllocateTensors`. Future
|
||||
// implementation of this method will attempt to apply the provided
|
||||
// `compute_settings` with a graceful fallback in case a failure occurs.
|
||||
// Note: before this gets implemented, do NOT call this method with non-empty
|
||||
// `compute_settings` otherwise an unimplemented error occurs.
|
||||
// Calls `interpreter_initializer` to construct the Interpreter, then
|
||||
// initializes it with the appropriate delegate (if any) specified through
|
||||
// `compute_settings` and finally calls AllocateTensors() on it.
|
||||
//
|
||||
// Whether or not this function automatically falls back to using CPU in case
|
||||
// initialization with a delegate fails depends on the FallbackSettings
|
||||
// specified in the TFLiteSettings of the provided ComputeSettings: if the
|
||||
// `allow_automatic_fallback_on_compilation_error` field is set to true,
|
||||
// fallback will automatically happen; otherwise an InternalError will be
|
||||
// thrown.
|
||||
// This flag allows callers to rely on this function whether or not they
|
||||
// actually want fallback to happen; if they don't, it will ensure that the
|
||||
// configuration doesn't accidentally trigger fallback.
|
||||
//
|
||||
// IMPORTANT: Supported delegate type includes: NONE, NNAPI, GPU, HEXAGON,
|
||||
// XNNPACK, EDGETPU (Google internal), and EDGETPU_CORAL. Specifying another
|
||||
// delegate type may cause an UnimplementedError to be thrown.
|
||||
absl::Status InitializeWithFallback(
|
||||
std::function<absl::Status(const InterpreterCreationResources&,
|
||||
std::unique_ptr<tflite::Interpreter>*)>
|
||||
interpreter_initializer,
|
||||
const tflite::proto::ComputeSettings& compute_settings);
|
||||
|
||||
// Deprecated: Use the one above with `InterpreterCreationResources` instead.
|
||||
absl::Status InitializeWithFallback(
|
||||
std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)>
|
||||
interpreter_initializer,
|
||||
const tflite::proto::ComputeSettings& compute_settings);
|
||||
|
||||
// Calls `set_inputs` and then Invoke() on the interpreter. Future
|
||||
// implementation of this method will perform a graceful fallback in case a
|
||||
// failure occur due to the `compute_settings` provided at initialization
|
||||
// time.
|
||||
// Calls `set_inputs` and then Invoke() on the interpreter.
|
||||
//
|
||||
// Whether or not this function automatically falls back to using CPU in case
|
||||
// invocation with a delegate fails depends on the FallbackSettings
|
||||
// specified in the TFLiteSettings of the ComputeSettings provided at
|
||||
// initialization: if the `allow_automatic_fallback_on_execution_error`
|
||||
// field is set to true, fallback will automatically happen; otherwise an
|
||||
// InternalError will be thrown.
|
||||
// This flag allows callers to rely on this function whether or not they
|
||||
// actually want fallback to happen; if they don't, it will ensure that the
|
||||
// configuration doesn't accidentally trigger fallback.
|
||||
absl::Status InvokeWithFallback(
|
||||
const std::function<absl::Status(tflite::Interpreter* interpreter)>&
|
||||
set_inputs);
|
||||
@ -58,8 +141,23 @@ class TfLiteInterpreterWrapper {
|
||||
// before-hand.
|
||||
absl::Status InvokeWithoutFallback();
|
||||
|
||||
// Cancels the current running TFLite invocation on CPU. This method is not
|
||||
// yet implemented though it is safe to use it as it acts as a NOP.
|
||||
// Cancels the current TFLite **CPU** inference.
|
||||
//
|
||||
// IMPORTANT: If inference is entirely running on a delegate, this has no
|
||||
// effect; if inference is partially delegated, only the CPU part is
|
||||
// cancelled.
|
||||
//
|
||||
// Usually called on a different thread than the one Invoke() is running
|
||||
// on. Calling Cancel() while InvokeWithFallback() or InvokeWithoutFallback()
|
||||
// is running may cause these methods to return a `CancelledError` with empty
|
||||
// results. Calling Cancel() at any other time doesn't have any effect.
|
||||
//
|
||||
// InvokeWithFallback() and InvokeWithoutFallback() reset the cancel flag
|
||||
// right before the underlying Invoke() is called, so these two methods can be
|
||||
// called again on the same instance after a call to Cancel().
|
||||
//
|
||||
// Note that this is the only method that can be called from another thread
|
||||
// without locking.
|
||||
void Cancel();
|
||||
|
||||
// Accesses the underlying interpreter for other methods.
|
||||
@ -72,8 +170,109 @@ class TfLiteInterpreterWrapper {
|
||||
TfLiteInterpreterWrapper(const TfLiteInterpreterWrapper&) = delete;
|
||||
TfLiteInterpreterWrapper& operator=(const TfLiteInterpreterWrapper&) = delete;
|
||||
|
||||
// Whether an error has occurred with the delegate.
|
||||
bool HasDelegateError() { return got_error_do_not_delegate_anymore_; }
|
||||
|
||||
// Whether the on-device mini-benchmark has completed for those TfLite
|
||||
// acceleration configurations that are specified in passed-in
|
||||
// ComputeSettings. If it finishes, the next time this same InterpreterWrapper
|
||||
// object is created (i.e. w/ the same model and the same
|
||||
// mini-benchmark-related configurations), the best acceleration configuration
|
||||
// will be picked up and used.
|
||||
bool HasMiniBenchmarkCompleted();
|
||||
|
||||
const tflite::proto::ComputeSettings& compute_settings() const {
|
||||
return compute_settings_;
|
||||
}
|
||||
|
||||
protected:
|
||||
// The delegate used to accelerate inference.
|
||||
Interpreter::TfLiteDelegatePtr delegate_;
|
||||
// The corresponding delegate plugin.
|
||||
std::unique_ptr<tflite::delegates::DelegatePluginInterface> delegate_plugin_;
|
||||
|
||||
private:
|
||||
// Performs sanity checks on the provided ComputeSettings.
|
||||
static absl::Status SanityCheckComputeSettings(
|
||||
const tflite::proto::ComputeSettings& compute_settings);
|
||||
|
||||
// Inner function for initializing an interpreter with fallback, optionally
|
||||
// resizing input tensors by calling `resize` on the newly initialized
|
||||
// interpreter.
|
||||
absl::Status InitializeWithFallbackAndResize(
|
||||
std::function<absl::Status(Interpreter* interpreter)> resize =
|
||||
[](Interpreter* interpreter) { return absl::OkStatus(); });
|
||||
|
||||
// Initializes the delegate plugin and creates the delegate.
|
||||
absl::Status InitializeDelegate();
|
||||
|
||||
// Wrapper around the interpreter's `AllocateTensors()` method converting the
|
||||
// returned `TfLiteStatus` to an `absl::Status`.
|
||||
absl::Status AllocateTensors();
|
||||
|
||||
absl::Status LoadDelegatePlugin(const std::string&,
|
||||
const tflite::TFLiteSettings&);
|
||||
|
||||
std::string ModelNamespace();
|
||||
std::string ModelID();
|
||||
|
||||
// The interpreter instance being used.
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
// The function used to initialize the interpreter and store it into the
|
||||
// provided `std::unique_ptr`.
|
||||
// This is typically a wrapper function around `tflite::InterpreterBuilder`,
|
||||
// giving the caller the opportunity to hook-up a custom `tflite::OpResolver`
|
||||
// and / or `tflite::ErrorReporter`.
|
||||
std::function<absl::Status(const InterpreterCreationResources&,
|
||||
std::unique_ptr<Interpreter>*)>
|
||||
interpreter_initializer_;
|
||||
|
||||
// The ComputeSettings provided at initialization time.
|
||||
// Note when TfLite mini-benchmark is enabled, it could be changed to the
|
||||
// best TfLite acceleration setting selected.
|
||||
tflite::proto::ComputeSettings compute_settings_;
|
||||
|
||||
// Set to true if an occurs with the specified delegate (if any), causing
|
||||
// future calls to fallback on CPU.
|
||||
bool got_error_do_not_delegate_anymore_;
|
||||
|
||||
// Fallback behavior as specified through the ComputeSettings.
|
||||
bool fallback_on_compilation_error_;
|
||||
bool fallback_on_execution_error_;
|
||||
|
||||
std::string default_model_namespace_;
|
||||
std::string default_model_id_;
|
||||
|
||||
// Used to convert the ComputeSettings proto to FlatBuffer format.
|
||||
flatbuffers::FlatBufferBuilder flatbuffers_builder_;
|
||||
|
||||
// Cancellation flag definition.
|
||||
struct CancelFlag {
|
||||
// Mutex to guard the `cancel_flag`.
|
||||
mutable absl::Mutex cancel_mutex;
|
||||
|
||||
// A flag indicates if the caller cancels the TFLite interpreter invocation.
|
||||
bool cancel_flag ABSL_GUARDED_BY(cancel_mutex) = false;
|
||||
|
||||
// Returns `cancel_flag`.
|
||||
bool Get() const ABSL_LOCKS_EXCLUDED(cancel_mutex) {
|
||||
absl::MutexLock cancel_lock(&cancel_mutex);
|
||||
return cancel_flag;
|
||||
}
|
||||
|
||||
// Sets `cancel_flag` to `value`.
|
||||
void Set(bool value) ABSL_LOCKS_EXCLUDED(cancel_mutex) {
|
||||
absl::MutexLock cancel_lock(&cancel_mutex);
|
||||
cancel_flag = value;
|
||||
}
|
||||
};
|
||||
CancelFlag cancel_flag_;
|
||||
|
||||
std::unique_ptr<tflite::acceleration::MiniBenchmark> mini_benchmark_;
|
||||
|
||||
// Sets up the TFLite invocation cancellation by
|
||||
// tflite::Interpreter::SetCancellationFunction().
|
||||
void SetTfLiteCancellation();
|
||||
};
|
||||
|
||||
} // namespace support
|
||||
|
@ -37,6 +37,19 @@ typedef unsigned long uword_t;
|
||||
#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also.
|
||||
#define GG_LL_FORMAT_W L"ll"
|
||||
|
||||
const uint8 kuint8max{0xFF};
|
||||
const uint16 kuint16max{0xFFFF};
|
||||
const uint32 kuint32max{0xFFFFFFFF};
|
||||
const uint64 kuint64max{GG_ULONGLONG(0xFFFFFFFFFFFFFFFF)};
|
||||
const int8 kint8min{~0x7F};
|
||||
const int8 kint8max{0x7F};
|
||||
const int16 kint16min{~0x7FFF};
|
||||
const int16 kint16max{0x7FFF};
|
||||
const int32 kint32min{~0x7FFFFFFF};
|
||||
const int32 kint32max{0x7FFFFFFF};
|
||||
const int64 kint64min{GG_LONGLONG(~0x7FFFFFFFFFFFFFFF)};
|
||||
const int64 kint64max{GG_LONGLONG(0x7FFFFFFFFFFFFFFF)};
|
||||
|
||||
typedef uint64 Fprint;
|
||||
static const Fprint kIllegalFprint = 0;
|
||||
static const Fprint kMaxFprint = GG_ULONGLONG(0xFFFFFFFFFFFFFFFF);
|
||||
|
21
third_party/tflite_support/src/tensorflow_lite_support/cc/port/status_matchers.h
vendored
Normal file
21
third_party/tflite_support/src/tensorflow_lite_support/cc/port/status_matchers.h
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_
|
||||
|
||||
#include "tensorflow_lite_support/cc/port/default/status_matchers.h"
|
||||
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MATCHERS_H_
|
@ -16,5 +16,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
|
||||
#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
|
||||
|
||||
#include "tensorflow_lite_support/cc/port/default/statusor.h"
|
||||
// This header file is used to manage the depended StatusOr library. It creates
|
||||
// an extra layer that makes it easier to switch between the desired version of
|
||||
// StatusOr.
|
||||
#include "absl/status/statusor.h" // from @com_google_absl
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
template <typename T>
|
||||
using StatusOr = absl::StatusOr<T>;
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
|
||||
|
@ -15,8 +15,8 @@ answer question based on context.
|
||||
Use the C++ API to answer questions as follows:
|
||||
|
||||
```cc
|
||||
using tflite::task::text::qa::BertQuestionAnswerer;
|
||||
using tflite::task::text::qa::QaAnswer;
|
||||
using tflite::task::text::BertQuestionAnswerer;
|
||||
using tflite::task::text::QaAnswer;
|
||||
// Create API handler with Mobile Bert model.
|
||||
auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile("/path/to/mobileBertModel", "/path/to/vocab");
|
||||
// Or create API handler with Albert model.
|
||||
|
65
third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/BUILD
vendored
Normal file
65
third_party/tflite_support/src/tensorflow_lite_support/cc/task/audio/BUILD
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
|
||||
"cc_library_with_tflite",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "audio_classifier",
|
||||
srcs = ["audio_classifier.cc"],
|
||||
hdrs = ["audio_classifier.h"],
|
||||
tflite_deps = [
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
"//tensorflow_lite_support/cc/task/processor:classification_postprocessor",
|
||||
"//tensorflow_lite_support/cc/task/processor:audio_preprocessor",
|
||||
"//tensorflow_lite_support/cc/task/core:base_task_api",
|
||||
"//tensorflow_lite_support/cc/task/core:task_api_factory",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow_lite_support/cc:common",
|
||||
"//tensorflow_lite_support/cc/port:integral_types",
|
||||
"//tensorflow_lite_support/cc/port:statusor",
|
||||
"//tensorflow_lite_support/cc/task/audio/core:audio_buffer",
|
||||
"//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto",
|
||||
"//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc",
|
||||
"//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc",
|
||||
"//tensorflow_lite_support/cc/task/core:classification_head",
|
||||
"//tensorflow_lite_support/cc/task/core:label_map_item",
|
||||
"//tensorflow_lite_support/cc/task/core:task_utils",
|
||||
"//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto",
|
||||
"//tensorflow_lite_support/metadata:metadata_schema_cc",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library_with_tflite(
|
||||
name = "audio_embedder",
|
||||
srcs = ["audio_embedder.cc"],
|
||||
hdrs = ["audio_embedder.h"],
|
||||
tflite_deps = [
|
||||
"//tensorflow_lite_support/cc/task/processor:embedding_postprocessor",
|
||||
"//tensorflow_lite_support/cc/task/processor:audio_preprocessor",
|
||||
"//tensorflow_lite_support/cc/task/core:task_api_factory",
|
||||
"//tensorflow_lite_support/cc/task/core:base_task_api",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow_lite_support/cc:common",
|
||||
"//tensorflow_lite_support/cc/port:status_macros",
|
||||
"//tensorflow_lite_support/cc/port:statusor",
|
||||
"//tensorflow_lite_support/cc/task/audio/proto:audio_embedder_options_cc_proto",
|
||||
"//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto",
|
||||
"//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
],
|
||||
)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user