app.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import logging
  6. from typing import Any, Generator
  7. from app_conf import (
  8. GALLERY_PATH,
  9. GALLERY_PREFIX,
  10. POSTERS_PATH,
  11. POSTERS_PREFIX,
  12. UPLOADS_PATH,
  13. UPLOADS_PREFIX,
  14. )
  15. from data.loader import preload_data
  16. from data.schema import schema
  17. from data.store import set_videos
  18. from flask import Flask, make_response, Request, request, Response, send_from_directory
  19. from flask_cors import CORS
  20. from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
  21. from inference.multipart import MultipartResponseBuilder
  22. from inference.predictor import InferenceAPI
  23. from strawberry.flask.views import GraphQLView
  24. logger = logging.getLogger(__name__)
  25. app = Flask(__name__)
  26. cors = CORS(app, supports_credentials=True)
  27. videos = preload_data()
  28. set_videos(videos)
  29. inference_api = InferenceAPI()
  30. @app.route("/healthy")
  31. def healthy() -> Response:
  32. return make_response("OK", 200)
  33. @app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
  34. def send_gallery_video(path: str) -> Response:
  35. try:
  36. return send_from_directory(
  37. GALLERY_PATH,
  38. path,
  39. )
  40. except:
  41. raise ValueError("resource not found")
  42. @app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
  43. def send_poster_image(path: str) -> Response:
  44. try:
  45. return send_from_directory(
  46. POSTERS_PATH,
  47. path,
  48. )
  49. except:
  50. raise ValueError("resource not found")
  51. @app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
  52. def send_uploaded_video(path: str):
  53. try:
  54. return send_from_directory(
  55. UPLOADS_PATH,
  56. path,
  57. )
  58. except:
  59. raise ValueError("resource not found")
  60. # TOOD: Protect route with ToS permission check
  61. @app.route("/propagate_in_video", methods=["POST"])
  62. def propagate_in_video() -> Response:
  63. data = request.json
  64. args = {
  65. "session_id": data["session_id"],
  66. "start_frame_index": data.get("start_frame_index", 0),
  67. }
  68. boundary = "frame"
  69. frame = gen_track_with_mask_stream(boundary, **args)
  70. return Response(frame, mimetype="multipart/x-savi-stream; boundary=" + boundary)
  71. def gen_track_with_mask_stream(
  72. boundary: str,
  73. session_id: str,
  74. start_frame_index: int,
  75. ) -> Generator[bytes, None, None]:
  76. with inference_api.autocast_context():
  77. request = PropagateInVideoRequest(
  78. type="propagate_in_video",
  79. session_id=session_id,
  80. start_frame_index=start_frame_index,
  81. )
  82. for chunk in inference_api.propagate_in_video(request=request):
  83. yield MultipartResponseBuilder.build(
  84. boundary=boundary,
  85. headers={
  86. "Content-Type": "application/json; charset=utf-8",
  87. "Frame-Current": "-1",
  88. # Total frames minus the reference frame
  89. "Frame-Total": "-1",
  90. "Mask-Type": "RLE[]",
  91. },
  92. body=chunk.to_json().encode("UTF-8"),
  93. ).get_message()
  94. class MyGraphQLView(GraphQLView):
  95. def get_context(self, request: Request, response: Response) -> Any:
  96. return {"inference_api": inference_api}
  97. # Add GraphQL route to Flask app.
  98. app.add_url_rule(
  99. "/graphql",
  100. view_func=MyGraphQLView.as_view(
  101. "graphql_view",
  102. schema=schema,
  103. # Disable GET queries
  104. # https://strawberry.rocks/docs/operations/deployment
  105. # https://strawberry.rocks/docs/integrations/flask
  106. allow_queries_via_get=False,
  107. # Strawberry recently changed multipart request handling, which now
  108. # requires enabling support explicitly for views.
  109. # https://github.com/strawberry-graphql/strawberry/issues/3655
  110. multipart_uploads_enabled=True,
  111. ),
  112. )
  113. if __name__ == "__main__":
  114. app.run(host="0.0.0.0", port=5000)