create_ml_io.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) <2015-Present> Tzutalin
  2. # Copyright (C) 2013 MIT, Computer Science and Artificial Intelligence Laboratory. Bryan Russell, Antonio Torralba,
  3. # William T. Freeman. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
  4. # associated documentation files (the "Software"), to deal in the Software without restriction, including without
  5. # limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
  6. # Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
  7. # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
  8. # the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
  9. # NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
  10. # SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
  11. # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  12. # THE SOFTWARE.
  13. #!/usr/bin/env python
  14. # -*- coding: utf8 -*-
  15. import json
  16. from pathlib import Path
  17. from libs.constants import DEFAULT_ENCODING
  18. import os
  19. JSON_EXT = '.json'
  20. ENCODE_METHOD = DEFAULT_ENCODING
  21. class CreateMLWriter:
  22. def __init__(self, foldername, filename, imgsize, shapes, outputfile, databasesrc='Unknown', localimgpath=None):
  23. self.foldername = foldername
  24. self.filename = filename
  25. self.databasesrc = databasesrc
  26. self.imgsize = imgsize
  27. self.boxlist = []
  28. self.localimgpath = localimgpath
  29. self.verified = False
  30. self.shapes = shapes
  31. self.outputfile = outputfile
  32. def write(self):
  33. if os.path.isfile(self.outputfile):
  34. with open(self.outputfile, "r") as file:
  35. input_data = file.read()
  36. outputdict = json.loads(input_data)
  37. else:
  38. outputdict = []
  39. outputimagedict = {
  40. "image": self.filename,
  41. "annotations": []
  42. }
  43. for shape in self.shapes:
  44. points = shape["points"]
  45. x1 = points[0][0]
  46. y1 = points[0][1]
  47. x2 = points[1][0]
  48. y2 = points[2][1]
  49. height, width, x, y = self.calculate_coordinates(x1, x2, y1, y2)
  50. shapedict = {
  51. "label": shape["label"],
  52. "coordinates": {
  53. "x": x,
  54. "y": y,
  55. "width": width,
  56. "height": height
  57. }
  58. }
  59. outputimagedict["annotations"].append(shapedict)
  60. # check if image already in output
  61. exists = False
  62. for i in range(0, len(outputdict)):
  63. if outputdict[i]["image"] == outputimagedict["image"]:
  64. exists = True
  65. outputdict[i] = outputimagedict
  66. break
  67. if not exists:
  68. outputdict.append(outputimagedict)
  69. Path(self.outputfile).write_text(json.dumps(outputdict), ENCODE_METHOD)
  70. def calculate_coordinates(self, x1, x2, y1, y2):
  71. if x1 < x2:
  72. xmin = x1
  73. xmax = x2
  74. else:
  75. xmin = x2
  76. xmax = x1
  77. if y1 < y2:
  78. ymin = y1
  79. ymax = y2
  80. else:
  81. ymin = y2
  82. ymax = y1
  83. width = xmax - xmin
  84. if width < 0:
  85. width = width * -1
  86. height = ymax - ymin
  87. # x and y from center of rect
  88. x = xmin + width / 2
  89. y = ymin + height / 2
  90. return height, width, x, y
  91. class CreateMLReader:
  92. def __init__(self, jsonpath, filepath):
  93. self.jsonpath = jsonpath
  94. self.shapes = []
  95. self.verified = False
  96. self.filename = filepath.split("/")[-1:][0]
  97. try:
  98. self.parse_json()
  99. except ValueError:
  100. print("JSON decoding failed")
  101. def parse_json(self):
  102. with open(self.jsonpath, "r") as file:
  103. inputdata = file.read()
  104. outputdict = json.loads(inputdata)
  105. self.verified = True
  106. if len(self.shapes) > 0:
  107. self.shapes = []
  108. for image in outputdict:
  109. if image["image"] == self.filename:
  110. for shape in image["annotations"]:
  111. self.add_shape(shape["label"], shape["coordinates"])
  112. def add_shape(self, label, bndbox):
  113. xmin = bndbox["x"] - (bndbox["width"] / 2)
  114. ymin = bndbox["y"] - (bndbox["height"] / 2)
  115. xmax = bndbox["x"] + (bndbox["width"] / 2)
  116. ymax = bndbox["y"] + (bndbox["height"] / 2)
  117. points = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
  118. self.shapes.append((label, points, None, None, True))
  119. def get_shapes(self):
  120. return self.shapes