| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import hashlib
- import os
- import shutil
- import tempfile
- from pathlib import Path
- from typing import Iterable, List, Optional, Tuple, Union
- import av
- import strawberry
- from app_conf import (
- DATA_PATH,
- DEFAULT_VIDEO_PATH,
- MAX_UPLOAD_VIDEO_DURATION,
- UPLOADS_PATH,
- UPLOADS_PREFIX,
- )
- from data.data_types import (
- AddPointsInput,
- CancelPropagateInVideo,
- CancelPropagateInVideoInput,
- ClearPointsInFrameInput,
- ClearPointsInVideo,
- ClearPointsInVideoInput,
- CloseSession,
- CloseSessionInput,
- RemoveObjectInput,
- RLEMask,
- RLEMaskForObject,
- RLEMaskListOnFrame,
- StartSession,
- StartSessionInput,
- Video,
- )
- from data.loader import get_video
- from data.store import get_videos
- from data.transcoder import get_video_metadata, transcode, VideoMetadata
- from inference.data_types import (
- AddPointsRequest,
- CancelPropagateInVideoRequest,
- CancelPropagateInVideoRequest,
- ClearPointsInFrameRequest,
- ClearPointsInVideoRequest,
- CloseSessionRequest,
- RemoveObjectRequest,
- StartSessionRequest,
- )
- from inference.predictor import InferenceAPI
- from strawberry import relay
- from strawberry.file_uploads import Upload
@strawberry.type
class Query:
    @strawberry.field
    def default_video(self) -> Video:
        """
        Return the default video.

        The default video is selected by the DEFAULT_VIDEO_PATH setting
        (imported from app_conf): the first video whose ``path`` equals it is
        returned. If no video matches, the first video produced by
        ``get_videos()`` is returned instead.

        Raises:
            LookupError: if ``get_videos()`` returns no videos at all.
        """
        all_videos = get_videos()
        for video in all_videos.values():
            if video.path == DEFAULT_VIDEO_PATH:
                return video
        # Fallback is the first video in the catalog.
        first_video = next(iter(all_videos.values()), None)
        if first_video is None:
            # Without this guard next() would raise a bare StopIteration,
            # which surfaces as an opaque resolver error.
            raise LookupError("no videos available to use as the default video")
        return first_video

    @relay.connection(relay.ListConnection[Video])
    def videos(
        self,
    ) -> Iterable[Video]:
        """
        Return all available videos (the values of ``get_videos()``).
        """
        all_videos = get_videos()
        return all_videos.values()
def _to_rle_mask_list_on_frame(frame_response) -> RLEMaskListOnFrame:
    """
    Convert one per-frame inference response into a RLEMaskListOnFrame.

    Every mask-returning resolver in Mutation reads the same attributes:
    ``frame_response.frame_index`` and ``frame_response.results``, where each
    result exposes ``object_id`` and ``mask.counts`` / ``mask.size``.
    """
    return RLEMaskListOnFrame(
        frame_index=frame_response.frame_index,
        rle_mask_list=[
            RLEMaskForObject(
                object_id=result.object_id,
                # NOTE(review): order="F" presumably denotes Fortran
                # (column-major) RLE ordering from the predictor — confirm.
                rle_mask=RLEMask(
                    counts=result.mask.counts, size=result.mask.size, order="F"
                ),
            )
            for result in frame_response.results
        ],
    )


@strawberry.type
class Mutation:
    @strawberry.mutation
    def upload_video(
        self,
        file: Upload,
        start_time_sec: Optional[float] = None,
        duration_time_sec: Optional[float] = None,
    ) -> Video:
        """
        Receive an uploaded video, trim/transcode it and store it locally.

        The upload is handed to ``process_video`` with the clip length capped
        at MAX_UPLOAD_VIDEO_DURATION. That function writes the transcoded file
        under UPLOADS_PATH. The stored file is then wrapped in a ``Video``
        without generating a poster image.
        """
        filepath, file_key, vm = process_video(
            file,
            max_time=MAX_UPLOAD_VIDEO_DURATION,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )
        return get_video(
            filepath,
            UPLOADS_PATH,
            file_key=file_key,
            width=vm.width,
            height=vm.height,
            generate_poster=False,
        )

    @strawberry.mutation
    def start_session(
        self, input: StartSessionInput, info: strawberry.Info
    ) -> StartSession:
        """Start an inference session for the video at DATA_PATH/<input.path>."""
        inference_api: InferenceAPI = info.context["inference_api"]
        # NOTE(review): input.path is client-supplied and is joined onto
        # DATA_PATH without normalization; a value containing ".." could point
        # outside DATA_PATH. Confirm whether the inference layer validates it.
        request = StartSessionRequest(
            type="start_session",
            path=f"{DATA_PATH}/{input.path}",
        )
        response = inference_api.start_session(request=request)
        return StartSession(session_id=response.session_id)

    @strawberry.mutation
    def close_session(
        self, input: CloseSessionInput, info: strawberry.Info
    ) -> CloseSession:
        """Close the given inference session and report whether it succeeded."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CloseSessionRequest(
            type="close_session",
            session_id=input.session_id,
        )
        response = inference_api.close_session(request)
        return CloseSession(success=response.success)

    @strawberry.mutation
    def add_points(
        self, input: AddPointsInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Add labelled points for an object on a frame; return that frame's masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = AddPointsRequest(
            type="add_points",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
            points=input.points,
            labels=input.labels,
            clear_old_points=input.clear_old_points,
        )
        response = inference_api.add_points(request)
        return _to_rle_mask_list_on_frame(response)

    @strawberry.mutation
    def remove_object(
        self, input: RemoveObjectInput, info: strawberry.Info
    ) -> List[RLEMaskListOnFrame]:
        """Remove an object from the session; return the masks of every affected frame."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = RemoveObjectRequest(
            type="remove_object", session_id=input.session_id, object_id=input.object_id
        )
        response = inference_api.remove_object(request)
        return [_to_rle_mask_list_on_frame(res) for res in response.results]

    @strawberry.mutation
    def clear_points_in_frame(
        self, input: ClearPointsInFrameInput, info: strawberry.Info
    ) -> RLEMaskListOnFrame:
        """Clear an object's points on one frame; return that frame's masks."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInFrameRequest(
            type="clear_points_in_frame",
            session_id=input.session_id,
            frame_index=input.frame_index,
            object_id=input.object_id,
        )
        response = inference_api.clear_points_in_frame(request)
        return _to_rle_mask_list_on_frame(response)

    @strawberry.mutation
    def clear_points_in_video(
        self, input: ClearPointsInVideoInput, info: strawberry.Info
    ) -> ClearPointsInVideo:
        """Clear all points in the session's video and report success."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = ClearPointsInVideoRequest(
            type="clear_points_in_video",
            session_id=input.session_id,
        )
        response = inference_api.clear_points_in_video(request)
        return ClearPointsInVideo(success=response.success)

    @strawberry.mutation
    def cancel_propagate_in_video(
        self, input: CancelPropagateInVideoInput, info: strawberry.Info
    ) -> CancelPropagateInVideo:
        """Ask the inference API to cancel propagation for the session."""
        inference_api: InferenceAPI = info.context["inference_api"]
        request = CancelPropagateInVideoRequest(
            type="cancel_propagate_in_video",
            session_id=input.session_id,
        )
        response = inference_api.cancel_propagate_in_video(request)
        return CancelPropagateInVideo(success=response.success)
def get_file_hash(video_path_or_file, chunk_size: int = 1 << 20) -> str:
    """
    Return the hex SHA-256 digest of a file's full contents.

    Args:
        video_path_or_file: either a filesystem path (``str`` or any
            ``os.PathLike``) or an open binary file object supporting
            ``seek`` and ``read``. A file object is rewound to offset 0
            first, so the digest always covers the whole file regardless of
            its current position. It is left positioned at EOF.
        chunk_size: bytes read per step. Hashing in chunks keeps memory
            bounded instead of loading the whole video at once.

    Returns:
        The 64-character lowercase hexadecimal SHA-256 digest.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # read(0) returns b"" immediately, which would silently hash nothing.
        raise ValueError("chunk_size must be positive")

    digest = hashlib.sha256()

    def _consume(stream) -> None:
        # Feed the stream into the digest until read() signals EOF with b"".
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)

    if isinstance(video_path_or_file, (str, os.PathLike)):
        with open(video_path_or_file, "rb") as in_f:
            _consume(in_f)
    else:
        video_path_or_file.seek(0)
        _consume(video_path_or_file)
    return digest.hexdigest()
- def _get_start_sec_duration_sec(
- start_time_sec: Union[float, None],
- duration_time_sec: Union[float, None],
- max_time: float,
- ) -> Tuple[float, float]:
- default_seek_t = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0"))
- if start_time_sec is None:
- start_time_sec = default_seek_t
- if duration_time_sec is not None:
- duration_time_sec = min(duration_time_sec, max_time)
- else:
- duration_time_sec = max_time
- return start_time_sec, duration_time_sec
def process_video(
    file: Upload,
    max_time: float,
    start_time_sec: Optional[float] = None,
    duration_time_sec: Optional[float] = None,
) -> Tuple[str, str, VideoMetadata]:
    """
    Validate, trim and transcode an uploaded video, then store it in UPLOADS_PATH.

    The upload is written to a temporary directory and probed with
    ``get_video_metadata``. It is then transcoded starting at
    ``start_time_sec`` for ``duration_time_sec`` seconds (capped at
    ``max_time``). The result is moved to ``UPLOADS_PATH/<sha256>.mp4``.

    Args:
        file: the uploaded file; its full contents are read with ``read()``.
        max_time: upper bound on the duration of the transcoded clip.
        start_time_sec: seek offset. Defaults to the VIDEO_ENCODE_SEEK_TIME
            environment variable (or 0).
        duration_time_sec: requested clip length, clamped to ``max_time``.

    Returns:
        ``(filepath, file_key, metadata)``: the stored file's path, its key
        under UPLOADS_PREFIX, and the metadata of the transcoded video.

    Raises:
        Exception: if the upload is not a decodable video, lacks a video
            stream, width/height or duration metadata, or if transcoding
            produced no frames.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        in_path = f"{tempdir}/in.mp4"
        out_path = f"{tempdir}/out.mp4"
        with open(in_path, "wb") as in_f:
            in_f.write(file.read())

        try:
            video_metadata = get_video_metadata(in_path)
        except av.InvalidDataError as err:
            raise Exception("not valid video file") from err
        if video_metadata.num_video_streams == 0:
            raise Exception("video container does not contain a video stream")
        if video_metadata.width is None or video_metadata.height is None:
            raise Exception("video container does not contain width or height metadata")
        if video_metadata.duration_sec in (None, 0):
            raise Exception("video container does not contain time duration metadata")

        start_time_sec, duration_time_sec = _get_start_sec_duration_sec(
            max_time=max_time,
            start_time_sec=start_time_sec,
            duration_time_sec=duration_time_sec,
        )

        # Transcode video to make sure videos returned to the app are all in
        # the same format, duration, resolution, fps.
        transcode(
            in_path,
            out_path,
            video_metadata,
            seek_t=start_time_sec,
            duration_time_sec=duration_time_sec,
        )
        os.remove(in_path)  # don't need original video now

        out_video_metadata = get_video_metadata(out_path)
        if out_video_metadata.num_video_frames == 0:
            raise Exception(
                "transcode produced empty video; check seek time or your input video"
            )

        # Content-addressed name: identical transcoded output maps to the same
        # file. get_file_hash rewinds the handle itself before reading.
        with open(out_path, "rb") as file_data:
            file_hash = get_file_hash(file_data)
        file_name = f"{file_hash}.mp4"
        file_key = f"{UPLOADS_PREFIX}/{file_name}"
        filepath = os.path.join(UPLOADS_PATH, file_name)
        shutil.move(out_path, filepath)
        return filepath, file_key, out_video_metadata
# Root GraphQL schema exposing the Query and Mutation types defined above.
schema = strawberry.Schema(query=Query, mutation=Mutation)
|