-
Notifications
You must be signed in to change notification settings - Fork 124
/
metadata.py
865 lines (687 loc) · 30.6 KB
/
metadata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
import copy
import inspect
import io
import os
import shutil
import sys
import tempfile
import warnings
import zipfile
import flatbuffers
from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.cc.python import _pywrap_metadata_version
from tensorflow_lite_support.metadata.flatbuffers_lib import _pywrap_flatbuffers
try:
# If exists, optionally use TensorFlow to open and check files. Used to
# support more than local file systems.
# In pip requirements, we don't necessarily need tensorflow as a dep.
import tensorflow as tf # pylint: disable=g-import-not-at-top
_open_file = tf.io.gfile.GFile
_exists_file = tf.io.gfile.exists
except ImportError as e:
# If TensorFlow package doesn't exist, fall back to original open and exists.
_open_file = open
_exists_file = os.path.exists
def _maybe_open_as_binary(filename, mode):
  """Returns a binary file-like for `filename`, opening it when needed.

  Args:
    filename: a path to open, or an already opened file-like object (anything
      that exposes a `read` attribute).
    mode: str, the file mode; "b" is appended when it is not already present.

  Returns:
    `filename` itself when it is a file-like, otherwise the file opened through
    `_open_file`.
  """
  already_file_like = hasattr(filename, "read")
  if already_file_like:
    return filename
  # Force binary mode without duplicating the "b" flag.
  binary_mode = mode + "b" if "b" not in mode else mode
  return _open_file(filename, binary_mode)
class _StreamOwningZipFile(zipfile.ZipFile):
  """ZipFile that also closes the stream opened on the caller's behalf.

  `zipfile.ZipFile` does not close a file object that was handed to its
  constructor, so a stream opened by `_open_as_zipfile` would otherwise stay
  open after the archive is closed.
  """

  # The stream to close together with the archive; set by `_open_as_zipfile`.
  _owned_stream = None

  def close(self):
    """Closes the archive and then the owned stream (at most once)."""
    try:
      super().close()
    finally:
      stream, self._owned_stream = self._owned_stream, None
      if stream is not None:
        stream.close()


def _open_as_zipfile(filename, mode="r"):
  """Open file as a zipfile.

  When `filename` is a file-like object it is used as is and its lifetime stays
  with the caller. When it is a path, the stream opened here is closed together
  with the returned archive.

  Args:
    filename: str or file-like or path-like, to the zipfile.
    mode: str, common file mode for zip.
          (See: https://docs.python.org/3/library/zipfile.html)

  Returns:
    A ZipFile object.
  """
  file_like = _maybe_open_as_binary(filename, mode)
  if file_like is filename:
    # Caller-provided stream: do not take ownership of it.
    return zipfile.ZipFile(file_like, mode)

  try:
    zf = _StreamOwningZipFile(file_like, mode)
  except Exception:
    # The archive could not be opened; do not leak the stream opened above.
    file_like.close()
    raise
  zf._owned_stream = file_like  # pylint: disable=protected-access
  return zf
def _is_zipfile(filename):
  """Tells whether `filename` (a path or a file-like) is a zip archive.

  The stream obtained from `_maybe_open_as_binary` is used as a context
  manager, so it is closed when the check is done.
  """
  with _maybe_open_as_binary(filename, "r") as stream:
    is_zip = zipfile.is_zipfile(stream)
  return is_zip
def get_path_to_datafile(path):
  """Resolves `path` against the directory of the calling file.

  It's a simple replacement of
  "tensorflow.python.platform.resource_loader.get_path_to_datafile".

  Args:
    path: a string resource path relative to the calling file.

  Returns:
    The directory of the file that called this function, joined with `path`.
  """
  # Frame 1 is the direct caller of this function.
  caller_frame = sys._getframe(1)  # pylint: disable=protected-access
  caller_dir = os.path.dirname(inspect.getfile(caller_frame))
  return os.path.join(caller_dir, path)
# Path to the metadata schema file, resolved relative to this module's own
# directory (get_path_to_datafile uses the location of the calling file).
_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
    "../metadata_schema.fbs")
# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
  """Packs metadata and associated files into TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect list of files that have been packed into the model or are supposed to
  be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs

  Example usage:
  Populate metadata and label file into an image classifier model.

  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using Flatbuffers API. Attach the label file onto the output
  tensor (the tensor of probabilities) in the metadata.

  Then, pack the metadata and label file into the model as follows.

    ```python
    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model file:
    populator = MetadataPopulator.with_model_file(model_file)
    # For metadata buffer (bytearray read from the metadata file), use:
    # populator.load_metadata_buffer(metadata_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    # For associated file buffer (bytearray read from the file), use:
    # populator.load_associated_file_buffers({"label.txt": b"file content"})
    populator.populate()

    # Populating a metadata file (or a metadata buffer) and associated files to
    # a model buffer:
    populator = MetadataPopulator.with_model_buffer(model_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files(["label.txt"])
    populator.populate()
    # Writing the updated model buffer into a file.
    updated_model_buf = populator.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)

    # Transferring metadata and associated files from another TFLite model:
    populator_dst = MetadataPopulator.with_model_buffer(model_buf)
    populator_dst.load_metadata_and_associated_files(src_model_buf)
    populator_dst.populate()
    updated_model_buf = populator_dst.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)
    ```

  Note that existing metadata buffer (if applied) will be overridden by the new
  metadata buffer.
  """
  # As Zip API is used to concatenate associated files after tflite model file,
  # the populating operation is developed based on a model file. For in-memory
  # model buffer, we create a tempfile to serve the populating operation.
  # Creating and deleting such a tempfile is handled by the class,
  # _MetadataPopulatorWithBuffer.

  METADATA_FIELD_NAME = "TFLITE_METADATA"
  TFLITE_FILE_IDENTIFIER = b"TFL3"
  METADATA_FILE_IDENTIFIER = b"M001"

  def __init__(self, model_file):
    """Constructor for MetadataPopulator.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    _assert_model_file_identifier(model_file)
    self._model_file = model_file
    self._metadata_buf = None
    # _associated_files is a dict of file name and file buffer.
    self._associated_files = {}

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataPopulator object that populates data to a model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataPopulator object.

    Raises:
      IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return cls(model_file)

  # TODO(b/141468993): investigate if type check can be applied to model_buf for
  # FB.
  @classmethod
  def with_model_buffer(cls, model_buf):
    """Creates a MetadataPopulator object that populates data to a model buffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Returns:
      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.

    Raises:
      ValueError: the model does not have the expected flatbuffer identifier.
    """
    return _MetadataPopulatorWithBuffer(model_buf)

  def get_model_buffer(self):
    """Gets the buffer of the model with packed metadata and associated files.

    Returns:
      The full content of the model file, as read in binary mode.
    """
    with _open_file(self._model_file, "rb") as f:
      return f.read()

  def get_packed_associated_file_list(self):
    """Gets a list of associated files packed to the model file.

    Returns:
      List of packed associated files. Empty when the model file is not a
      zipfile, i.e. nothing has been packed yet.
    """
    if not _is_zipfile(self._model_file):
      return []

    with _open_as_zipfile(self._model_file, "r") as zf:
      return zf.namelist()

  def get_recorded_associated_file_list(self):
    """Gets a list of associated files recorded in metadata of the model file.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Returns:
      List of recorded associated files. Empty when no metadata has been
      loaded into the populator.
    """
    if not self._metadata_buf:
      return []

    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
            self._metadata_buf, 0))

    return [
        associated_file.name.decode("utf-8") for associated_file in
        self._get_recorded_associated_file_object_list(metadata)
    ]

  def load_associated_file_buffers(self, associated_files):
    """Loads the associated file buffers (in bytearray) to be populated.

    Args:
      associated_files: a dictionary of associated file names and corresponding
        file buffers, such as {"file.txt": b"file content"}. If pass in file
        paths for the file name, only the basename will be populated.
    """
    self._associated_files.update({
        os.path.basename(name): buffers
        for name, buffers in associated_files.items()
    })

  def load_associated_files(self, associated_files):
    """Loads associated files that to be concatenated after the model file.

    Args:
      associated_files: list of file paths.

    Raises:
      IOError:
        File not found.
    """
    for af_name in associated_files:
      _assert_file_exist(af_name)
      with _open_file(af_name, "rb") as af:
        self.load_associated_file_buffers({af_name: af.read()})

  def load_metadata_buffer(self, metadata_buf):
    """Loads the metadata buffer (in bytearray) to be populated.

    Args:
      metadata_buf: metadata buffer (in bytearray) to be populated.

    Raises:
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    if not metadata_buf:
      raise ValueError("The metadata to be populated is empty.")

    self._validate_metadata(metadata_buf)

    # Gets the minimum metadata parser version of the metadata_buf.
    min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
        bytes(metadata_buf))

    # Inserts in the minimum metadata parser version into the metadata_buf.
    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version

    # Remove local file directory in the `name` field of `AssociatedFileT`, and
    # make it consistent with the name of the actual file packed in the model.
    self._use_basename_for_associated_files_in_metadata(metadata)

    b = flatbuffers.Builder(0)
    b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
    metadata_buf_with_version = b.Output()

    self._metadata_buf = metadata_buf_with_version

  def load_metadata_file(self, metadata_file):
    """Loads the metadata file to be populated.

    Args:
      metadata_file: path to the metadata file to be populated.

    Raises:
      IOError: File not found.
      ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
      ValueError: Cannot get minimum metadata parser version.
      ValueError: The number of SubgraphMetadata is not 1.
      ValueError: The number of input/output tensors does not match the number
        of input/output tensor metadata.
    """
    _assert_file_exist(metadata_file)
    with _open_file(metadata_file, "rb") as f:
      metadata_buf = f.read()
    self.load_metadata_buffer(bytearray(metadata_buf))

  def load_metadata_and_associated_files(self, src_model_buf):
    """Loads the metadata and associated files from another model buffer.

    Args:
      src_model_buf: source model buffer (in bytearray) with metadata and
        associated files.
    """
    # Load the model metadata from src_model_buf if exist.
    metadata_buffer = get_metadata_buffer(src_model_buf)
    if metadata_buffer:
      self.load_metadata_buffer(metadata_buffer)

    # Load the associated files from src_model_buf if exist.
    if _is_zipfile(io.BytesIO(src_model_buf)):
      with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
        self.load_associated_file_buffers(
            {f: zf.read(f) for f in zf.namelist()})

  def populate(self):
    """Populates loaded metadata and associated files into the model file.

    If no metadata has been loaded into the populator, the metadata already
    stored in the model file (if any) is left untouched and only the associated
    files are packed.

    Raises:
      ValueError: see `_assert_validate`.
    """
    self._assert_validate()
    self._populate_metadata_buffer()
    self._populate_associated_files()

  def _assert_validate(self):
    """Validates the metadata and associated files to be populated.

    Raises:
      ValueError:
        File is recorded in the metadata, but is not going to be populated.
        File has already been packed.
    """
    # Gets files that are recorded in metadata.
    recorded_files = self.get_recorded_associated_file_list()
    recorded_set = set(recorded_files)

    # Gets files that have been packed to self._model_file.
    packed_set = set(self.get_packed_associated_file_list())

    # Gets the file name of those associated files to be populated.
    to_be_populated_files = self._associated_files.keys()

    # Checks all files recorded in the metadata will be populated.
    for rf in recorded_files:
      if rf not in to_be_populated_files and rf not in packed_set:
        raise ValueError("File, '{0}', is recorded in the metadata, but has "
                         "not been loaded into the populator.".format(rf))

    for f in to_be_populated_files:
      if f in packed_set:
        raise ValueError("File, '{0}', has already been packed.".format(f))

      if f not in recorded_set:
        warnings.warn(
            "File, '{0}', does not exist in the metadata. But packing it to "
            "tflite model is still allowed.".format(f))

  def _copy_archived_files(self, src_zip, file_list, dst_zip):
    """Copies archived files in file_list from src_zip to dst_zip.

    Args:
      src_zip: path or file-like of the source zipfile.
      file_list: names of the archive members to copy.
      dst_zip: path or file-like the members are appended to (opened in "a"
        mode).

    Raises:
      ValueError: src_zip is not a zipfile, or a requested member is missing
        from it.
    """
    if not _is_zipfile(src_zip):
      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))

    with _open_as_zipfile(src_zip, "r") as src_zf, \
         _open_as_zipfile(dst_zip, "a") as dst_zf:
      src_names = set(src_zf.namelist())
      for f in file_list:
        if f not in src_names:
          raise ValueError(
              "File, '{0}', does not exist in the zipfile, {1}.".format(
                  f, src_zip))
        file_buffer = src_zf.read(f)
        dst_zf.writestr(f, file_buffer)

  def _get_associated_files_from_process_units(self, table, field_name):
    """Gets the files that are attached the process units field of a table.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
      field_name: the name of the field in the table that represents an array of
        ProcessUnit. If the table is TensorMetadata, field_name can be
        "ProcessUnits". If the table is SubGraphMetadata, field_name can be
        either "InputProcessUnits" or "OutputProcessUnits".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    file_list = []
    process_units = getattr(table, field_name)
    # If the process_units field is not populated, it will be None. Use an
    # empty list to skip the check.
    for process_unit in process_units or []:
      options = process_unit.options
      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
                              _metadata_fb.RegexTokenizerOptionsT)):
        file_list += self._get_associated_files_from_table(options, "vocabFile")
      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
        file_list += self._get_associated_files_from_table(
            options, "sentencePieceModel")
        file_list += self._get_associated_files_from_table(options, "vocabFile")
    return file_list

  def _get_associated_files_from_table(self, table, field_name):
    """Gets the associated files that are attached a table directly.

    Args:
      table: a Flatbuffers table object that contains fields of an array of
        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
      field_name: the name of the field in the table that represents an array of
        AssociatedFile. If the table is TensorMetadata, field_name can be
        "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
        be "VocabFile".

    Returns:
      A list of AssociatedFileT objects.
    """

    if table is None:
      return []

    # If the associated file field is not populated,
    # `getattr(table, field_name)` will be None. Return an empty list.
    return getattr(table, field_name) or []

  def _get_recorded_associated_file_object_list(self, metadata):
    """Gets a list of AssociatedFileT objects recorded in the metadata.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Args:
      metadata: the ModelMetadataT object.

    Returns:
      List of recorded AssociatedFileT objects.
    """
    recorded_files = []

    # Add associated files attached to ModelMetadata.
    recorded_files += self._get_associated_files_from_table(
        metadata, "associatedFiles")

    # Add associated files attached to each SubgraphMetadata.
    for subgraph in metadata.subgraphMetadata or []:
      recorded_files += self._get_associated_files_from_table(
          subgraph, "associatedFiles")

      # Add associated files attached to each input tensor.
      for tensor_metadata in subgraph.inputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to each output tensor.
      for tensor_metadata in subgraph.outputTensorMetadata or []:
        recorded_files += self._get_associated_files_from_table(
            tensor_metadata, "associatedFiles")
        recorded_files += self._get_associated_files_from_process_units(
            tensor_metadata, "processUnits")

      # Add associated files attached to the input_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "inputProcessUnits")

      # Add associated files attached to the output_process_units.
      recorded_files += self._get_associated_files_from_process_units(
          subgraph, "outputProcessUnits")

    return recorded_files

  def _populate_associated_files(self):
    """Concatenates associated files after TensorFlow Lite model file.

    If the MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.
    """
    # Opens up the model file in "appending" mode.
    # If self._model_file already has pack files, zipfile will concatenate
    # addition files after self._model_file. For example, suppose we have
    # self._model_file = old_tflite_file | label1.txt | label2.txt
    # Then after trigger populate() to add label3.txt, self._model_file becomes
    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
    with tempfile.SpooledTemporaryFile() as temp:
      # (1) Copy content from model file of to temp file.
      with _open_file(self._model_file, "rb") as f:
        shutil.copyfileobj(f, temp)

      # (2) Append of to a temp file as a zip.
      with _open_as_zipfile(temp, "a") as zf:
        for file_name, file_buffer in self._associated_files.items():
          zf.writestr(file_name, file_buffer)

      # (3) Copy temp file to model file.
      temp.seek(0)
      with _open_file(self._model_file, "wb") as f:
        shutil.copyfileobj(temp, f)

  def _populate_metadata_buffer(self):
    """Populates the metadata buffer (in bytearray) into the model file.

    Inserts metadata_buf into the metadata field of schema.Model. If the
    MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.

    Existing metadata buffer (if applied) will be overridden by the new metadata
    buffer. When no metadata has been loaded, the model file is left as is so
    that metadata already stored in it is not replaced by an empty buffer.
    """
    if self._metadata_buf is None:
      # Nothing to write; do not clobber metadata already in the model.
      return

    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()

    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = self._metadata_buf

    is_populated = False
    if not model.metadata:
      model.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model.metadata:
        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
          is_populated = True
          model.buffers[meta.buffer] = buffer_field

    if not is_populated:
      if not model.buffers:
        model.buffers = []
      model.buffers.append(buffer_field)
      # Creates a new metadata field.
      metadata_field = _schema_fb.MetadataT()
      metadata_field.name = self.METADATA_FIELD_NAME
      metadata_field.buffer = len(model.buffers) - 1
      model.metadata.append(metadata_field)

    # Packs model back to a flatbuffer binary file.
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
    model_buf = b.Output()

    # Saves the updated model buffer to model file.
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    if packed_files:
      # Writes the updated model buffer and associated files into a new model
      # file (in memory). Then overwrites the original model file.
      with tempfile.SpooledTemporaryFile() as temp:
        temp.write(model_buf)
        self._copy_archived_files(self._model_file, packed_files, temp)
        temp.seek(0)
        with _open_file(self._model_file, "wb") as f:
          shutil.copyfileobj(temp, f)
    else:
      with _open_file(self._model_file, "wb") as f:
        f.write(model_buf)

  def _use_basename_for_associated_files_in_metadata(self, metadata):
    """Removes any associated file local directory (if exists)."""
    for associated_file in self._get_recorded_associated_file_object_list(
        metadata):
      associated_file.name = os.path.basename(associated_file.name)

  def _validate_metadata(self, metadata_buf):
    """Validates the metadata to be populated.

    Args:
      metadata_buf: the metadata buffer to check against the model file.

    Raises:
      ValueError: the buffer lacks the metadata identifier, the number of
        SubgraphMetadata is not exactly one, or the number of input/output
        tensor metadata differs from the number of input/output tensors of the
        model's first subgraph.
    """
    _assert_metadata_buffer_identifier(metadata_buf)

    # Verify the number of SubgraphMetadata is exactly one.
    # TFLite currently only support one subgraph.
    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
        metadata_buf, 0)
    num_subgraph_meta = model_meta.SubgraphMetadataLength()
    if num_subgraph_meta != 1:
      raise ValueError("The number of SubgraphMetadata should be exactly one, "
                       "but got {0}.".format(num_subgraph_meta))

    # Verify if the number of tensor metadata matches the number of tensors.
    with _open_file(self._model_file, "rb") as f:
      model_buf = f.read()
    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
    subgraph = model.Subgraphs(0)
    subgraph_meta = model_meta.SubgraphMetadata(0)

    num_input_tensors = subgraph.InputsLength()
    num_input_meta = subgraph_meta.InputTensorMetadataLength()
    if num_input_tensors != num_input_meta:
      raise ValueError(
          "The number of input tensors ({0}) should match the number of "
          "input tensor metadata ({1})".format(num_input_tensors,
                                               num_input_meta))

    num_output_tensors = subgraph.OutputsLength()
    num_output_meta = subgraph_meta.OutputTensorMetadataLength()
    if num_output_tensors != num_output_meta:
      raise ValueError(
          "The number of output tensors ({0}) should match the number of "
          "output tensor metadata ({1})".format(num_output_tensors,
                                                num_output_meta))
class _MetadataPopulatorWithBuffer(MetadataPopulator):
  """Subclass of MetadataPopulator that populates metadata to a model buffer.

  This class is used to populate metadata into a in-memory model buffer. As we
  use Zip API to concatenate associated files after tflite model file, the
  populating operation is developed based on a model file. For in-memory model
  buffer, we create a tempfile to serve the populating operation. This class is
  then used to generate this tempfile, and delete the file when the
  MetadataPopulator object is deleted.
  """

  def __init__(self, model_buf):
    """Constructor for _MetadataPopulatorWithBuffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Raises:
      ValueError: model_buf is empty.
      ValueError: model_buf does not have the expected flatbuffer identifier.
    """
    # Defined up front so that __del__ is safe even if construction fails.
    self._model_file = None

    if not model_buf:
      raise ValueError("model_buf cannot be empty.")

    # mkstemp creates the file and keeps it on disk, so the path is never
    # released and re-created by name in between.
    fd, model_file = tempfile.mkstemp()
    try:
      with os.fdopen(fd, "wb") as f:
        f.write(model_buf)
      super().__init__(model_file)
    except Exception:
      # Validation (or the write) failed: do not leave the tempfile behind.
      os.remove(model_file)
      raise

  def __del__(self):
    """Destructor of _MetadataPopulatorWithBuffer.

    Deletes the tempfile, if one was created.
    """
    model_file = getattr(self, "_model_file", None)
    if model_file and os.path.exists(model_file):
      os.remove(model_file)
class MetadataDisplayer(object):
  """Exposes a model's metadata and packed files in human-readable form."""

  def __init__(self, model_buffer, metadata_buffer, associated_file_list):
    """Builds a displayer from already extracted pieces of a model.

    Args:
      model_buffer: valid buffer of the model file.
      metadata_buffer: valid buffer of the metadata file.
      associated_file_list: list of associate files in the model file.

    Raises:
      ValueError: either buffer lacks its expected flatbuffer identifier.
    """
    # Both identifier checks run before any state is stored.
    _assert_model_buffer_identifier(model_buffer)
    _assert_metadata_buffer_identifier(metadata_buffer)
    self._associated_file_list = associated_file_list
    self._metadata_buffer = metadata_buffer
    self._model_buffer = model_buffer

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataDisplayer object for the model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataDisplayer object.

    Raises:
      IOError: File not found.
      ValueError: The model does not have metadata.
    """
    _assert_file_exist(model_file)
    with _open_file(model_file, "rb") as model_stream:
      model_buffer = model_stream.read()
    return cls.with_model_buffer(model_buffer)

  @classmethod
  def with_model_buffer(cls, model_buffer):
    """Creates a MetadataDisplayer object for a file buffer.

    Args:
      model_buffer: TensorFlow Lite model buffer in bytearray.

    Returns:
      MetadataDisplayer object.

    Raises:
      ValueError: model_buffer is empty, or the model does not have metadata.
    """
    if not model_buffer:
      raise ValueError("model_buffer cannot be empty.")

    extracted_metadata = get_metadata_buffer(model_buffer)
    if not extracted_metadata:
      raise ValueError("The model does not have metadata.")

    packed_names = cls._parse_packed_associted_file_list(model_buffer)
    return cls(model_buffer, extracted_metadata, packed_names)

  def get_associated_file_buffer(self, filename):
    """Reads the content of one associated file packed in the model.

    Args:
      filename: name of the file to be extracted.

    Returns:
      The content of the file as read from the archive.

    Raises:
      ValueError: if the file does not exist in the model.
    """
    is_known = filename in self._associated_file_list
    if not is_known:
      raise ValueError(
          "The file, {}, does not exist in the model.".format(filename))

    model_stream = io.BytesIO(self._model_buffer)
    with _open_as_zipfile(model_stream) as archive:
      content = archive.read(filename)
    return content

  def get_metadata_buffer(self):
    """Returns a deep copy of the metadata buffer held by this displayer."""
    metadata_copy = copy.deepcopy(self._metadata_buffer)
    return metadata_copy

  def get_metadata_json(self):
    """Renders the metadata as a json string via `convert_to_json`."""
    metadata_json = convert_to_json(self._metadata_buffer)
    return metadata_json

  def get_packed_associated_file_list(self):
    """Lists the associated files that are packed in the model.

    Returns:
      A deep copy of the name list of associated files.
    """
    names_copy = copy.deepcopy(self._associated_file_list)
    return names_copy

  @staticmethod
  def _parse_packed_associted_file_list(model_buf):
    """Reads the names of the files packed after the model in `model_buf`.

    Args:
      model_buf: valid file buffer.

    Returns:
      List of packed associated files; empty when the buffer is not a valid
      zipfile.
    """
    try:
      with _open_as_zipfile(io.BytesIO(model_buf)) as archive:
        packed_names = archive.namelist()
    except zipfile.BadZipFile:
      # No archive appended to the model: nothing has been packed.
      return []
    return packed_names
# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer):
  """Renders a metadata buffer as JSON text using the metadata schema file.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  idl_options = _pywrap_flatbuffers.IDLOptions()
  idl_options.strict_json = True
  schema_parser = _pywrap_flatbuffers.Parser(idl_options)

  # The parser needs the schema text before it can render the buffer.
  with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as schema_file:
    schema_text = schema_file.read()

  parsed_ok = schema_parser.parse(schema_text)
  if not parsed_ok:
    raise ValueError(
        "Cannot parse metadata schema. Reason: " + schema_parser.error)
  return _pywrap_flatbuffers.generate_text(schema_parser, metadata_buffer)
def _assert_file_exist(filename):
  """Raises IOError unless `filename` exists according to `_exists_file`."""
  if _exists_file(filename):
    return
  raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
  """Verifies that the file at `model_file` carries the TFLite identifier.

  Raises:
    IOError: the file does not exist.
    ValueError: the file content lacks the expected identifier.
  """
  _assert_file_exist(model_file)
  with _open_file(model_file, "rb") as model_stream:
    model_content = model_stream.read()
  _assert_model_buffer_identifier(model_content)
def _assert_model_buffer_identifier(model_buf):
  """Raises ValueError unless `model_buf` has the TFLite schema identifier."""
  has_identifier = _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0)
  if has_identifier:
    return
  raise ValueError(
      "The model provided does not have the expected identifier, and "
      "may not be a valid TFLite model.")
def _assert_metadata_buffer_identifier(metadata_buf):
  """Raises ValueError unless `metadata_buf` has the Metadata identifier."""
  has_identifier = _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
      metadata_buf, 0)
  if has_identifier:
    return
  raise ValueError(
      "The metadata buffer does not have the expected identifier, and may not"
      " be a valid TFLite Metadata.")
def get_metadata_buffer(model_buf):
  """Returns the metadata in the model file as a buffer.

  Args:
    model_buf: valid buffer of the model file.

  Returns:
    Metadata buffer. Returns `None` if the model does not have metadata, or if
    the metadata entry points to an empty buffer.
  """
  tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

  # Gets metadata from the model file.
  for i in range(tflite_model.MetadataLength()):
    meta = tflite_model.Metadata(i)
    if meta.Name().decode("utf-8") != MetadataPopulator.METADATA_FIELD_NAME:
      continue
    metadata = tflite_model.Buffers(meta.Buffer())
    if metadata.DataLength() == 0:
      # An empty buffer carries no metadata. NOTE(review): the generated
      # DataAsNumpy() accessor presumably does not return an array for an
      # empty vector, so `.tobytes()` must not be called on it - confirm
      # against the generated schema code.
      continue
    return metadata.DataAsNumpy().tobytes()
  return None