| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import logging
- from typing import Any, Generator
- from app_conf import (
- GALLERY_PATH,
- GALLERY_PREFIX,
- POSTERS_PATH,
- POSTERS_PREFIX,
- UPLOADS_PATH,
- UPLOADS_PREFIX,
- )
- from data.loader import preload_data
- from data.schema import schema
- from data.store import set_videos
- from flask import Flask, make_response, Request, request, Response, send_from_directory
- from flask_cors import CORS
- from inference.data_types import PropagateDataResponse, PropagateInVideoRequest
- from inference.multipart import MultipartResponseBuilder
- from inference.predictor import InferenceAPI
- from strawberry.flask.views import GraphQLView
logger = logging.getLogger(__name__)

# Flask application. CORS is enabled with ``supports_credentials=True`` so
# cross-origin requests may carry credentials.
app = Flask(__name__)
cors = CORS(app, supports_credentials=True)

# Import-time data setup: load the video data with ``preload_data()`` and hand
# it to the data store with ``set_videos()`` before any request is served.
videos = preload_data()
set_videos(videos)

# A single shared inference engine. It is used by the /propagate_in_video
# route and injected into the GraphQL context by MyGraphQLView.
inference_api = InferenceAPI()
@app.route("/healthy")
def healthy() -> Response:
    """Health-check endpoint: always answer HTTP 200 with the body "OK"."""
    status_code = 200
    return make_response("OK", status_code)
def _send_static(directory: Any, path: str) -> Response:
    """Serve ``path`` relative to ``directory`` using ``send_from_directory``.

    Any failure while resolving or sending the file is re-raised as
    ``ValueError("resource not found")``. The original exception is chained
    as ``__cause__`` so the real reason stays visible in logs and tracebacks.

    Only ``Exception`` subclasses are caught. A bare ``except:`` would also
    intercept ``KeyboardInterrupt``/``SystemExit`` and turn them into a
    ``ValueError``.

    Args:
        directory: Base directory the file must live under.
        path: Path of the requested file, relative to ``directory``.

    Raises:
        ValueError: if the file cannot be served.
    """
    try:
        return send_from_directory(directory, path)
    except Exception as err:
        # NOTE(review): Flask reports an unhandled ValueError as HTTP 500; a
        # missing file would more naturally be a 404 (e.g. ``abort(404)``).
        raise ValueError("resource not found") from err


@app.route(f"/{GALLERY_PREFIX}/<path:path>", methods=["GET"])
def send_gallery_video(path: str) -> Response:
    """Serve a file from ``GALLERY_PATH``; raise ValueError if it cannot be sent."""
    return _send_static(GALLERY_PATH, path)


@app.route(f"/{POSTERS_PREFIX}/<path:path>", methods=["GET"])
def send_poster_image(path: str) -> Response:
    """Serve a file from ``POSTERS_PATH``; raise ValueError if it cannot be sent."""
    return _send_static(POSTERS_PATH, path)


@app.route(f"/{UPLOADS_PREFIX}/<path:path>", methods=["GET"])
def send_uploaded_video(path: str) -> Response:
    """Serve a file from ``UPLOADS_PATH``; raise ValueError if it cannot be sent."""
    return _send_static(UPLOADS_PATH, path)
# TODO: Protect this route with a ToS permission check.
@app.route("/propagate_in_video", methods=["POST"])
def propagate_in_video() -> Response:
    """Stream mask-propagation results for a session as a multipart response.

    Expects a JSON object body with a required ``session_id`` and an optional
    ``start_frame_index`` (default ``0``). The response body is generated
    lazily by ``gen_track_with_mask_stream`` using the boundary ``"frame"``.

    Returns:
        A streaming ``Response`` with mimetype
        ``multipart/x-savi-stream; boundary=frame``. If the body is not a
        JSON object or lacks ``session_id``, returns a 400 response instead
        of failing with ``TypeError``/``KeyError`` (an HTTP 500).
    """
    data = request.json
    # The body comes from the client: validate it before indexing.
    if not isinstance(data, dict) or "session_id" not in data:
        return make_response("session_id is required", 400)

    boundary = "frame"
    frames = gen_track_with_mask_stream(
        boundary,
        session_id=data["session_id"],
        start_frame_index=data.get("start_frame_index", 0),
    )
    return Response(frames, mimetype=f"multipart/x-savi-stream; boundary={boundary}")
def gen_track_with_mask_stream(
    boundary: str,
    session_id: str,
    start_frame_index: int,
) -> Generator[bytes, None, None]:
    """Yield multipart-encoded parts for a ``propagate_in_video`` run.

    Args:
        boundary: Multipart boundary string. It must be the same boundary the
            caller advertises in the response mimetype.
        session_id: Identifier of the inference session to propagate in.
        start_frame_index: Frame index passed to the propagation request.

    Yields:
        One encoded multipart message per chunk returned by
        ``inference_api.propagate_in_video``. Each part's body is the chunk's
        ``to_json()`` output encoded as UTF-8.
    """
    # Propagation runs inside the inference API's autocast context. Because
    # this is a generator, the context stays open until the stream is
    # exhausted or the generator is closed.
    with inference_api.autocast_context():
        # Named ``propagate_request`` so it does not shadow the Flask
        # ``request`` proxy imported at module level.
        propagate_request = PropagateInVideoRequest(
            type="propagate_in_video",
            session_id=session_id,
            start_frame_index=start_frame_index,
        )
        for chunk in inference_api.propagate_in_video(request=propagate_request):
            yield MultipartResponseBuilder.build(
                boundary=boundary,
                headers={
                    "Content-Type": "application/json; charset=utf-8",
                    # NOTE(review): both frame counters are always sent as
                    # "-1", presumably a placeholder for "unknown". The
                    # earlier comment described Frame-Total as "total frames
                    # minus the reference frame"; confirm with the client.
                    "Frame-Current": "-1",
                    "Frame-Total": "-1",
                    "Mask-Type": "RLE[]",
                },
                body=chunk.to_json().encode("UTF-8"),
            ).get_message()
class MyGraphQLView(GraphQLView):
    """Strawberry Flask GraphQL view that exposes the shared inference engine."""

    def get_context(self, request: Request, response: Response) -> Any:
        """Build the GraphQL context for one request.

        The context is a dict holding the module-level ``inference_api``
        under the key ``"inference_api"``. It is the same object for every
        request.
        """
        context = dict(inference_api=inference_api)
        return context
- # Add GraphQL route to Flask app.
app.add_url_rule(
    rule="/graphql",
    view_func=MyGraphQLView.as_view(
        "graphql_view",
        schema=schema,
        # Newer Strawberry versions require multipart (file upload) requests
        # to be enabled explicitly on each view.
        # https://github.com/strawberry-graphql/strawberry/issues/3655
        multipart_uploads_enabled=True,
        # Reject queries sent over GET; only POST is accepted.
        # https://strawberry.rocks/docs/operations/deployment
        # https://strawberry.rocks/docs/integrations/flask
        allow_queries_via_get=False,
    ),
)
if __name__ == "__main__":
    # Direct-execution entry point: Flask's built-in server on port 5000,
    # bound to every network interface ("0.0.0.0").
    app.run(port=5000, host="0.0.0.0")
|