diff --git a/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb6-8x8x1-cosine-10e_ava22-rgb.py b/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb6-8x8x1-cosine-10e_ava22-rgb.py index 3cbf483e57..a7f4c09ed1 100644 --- a/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb6-8x8x1-cosine-10e_ava22-rgb.py +++ b/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb6-8x8x1-cosine-10e_ava22-rgb.py @@ -1,4 +1,4 @@ -_base_ = ['slowfast_kinetics400-pretrained-r50_8xb16-8x8x1-20e_ava21-rgb.py'] +_base_ = ['slowfast_kinetics400-pretrained-r50_8xb8-8x8x1-20e_ava21-rgb.py'] model = dict( backbone=dict( diff --git a/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb8-8x8x1-20e_ava21-rgb.py b/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb8-8x8x1-20e_ava21-rgb.py index c1f8b4d6e0..97e0197a6e 100644 --- a/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb8-8x8x1-20e_ava21-rgb.py +++ b/configs/detection/ava/slowfast_kinetics400-pretrained-r50_8xb8-8x8x1-20e_ava21-rgb.py @@ -1,11 +1,11 @@ -_base_ = ['slowonly_kinetics400-pretrained-r50_8xb16-4x16x1-20e_ava21-rgb.py'] +_base_ = ['slowfast_kinetics400-pretrained-r50_8xb16-4x16x1-20e_ava21-rgb.py'] model = dict( backbone=dict( resample_rate=4, speed_ratio=4, slow_pathway=dict(fusion_kernel=7), - prtrained=( + pretrained=( 'https://download.openmmlab.com/mmaction/recognition/slowfast/' 'slowfast_r50_8x8x1_256e_kinetics400_rgb/' 'slowfast_r50_8x8x1_256e_kinetics400_rgb_20200716-73547d2b.pth'))) diff --git a/demo/demo_spatiotemporal_det.mp4 b/demo/demo_spatiotemporal_det.mp4 new file mode 100644 index 0000000000..007e1d7f25 Binary files /dev/null and b/demo/demo_spatiotemporal_det.mp4 differ diff --git a/demo/demo_spatiotemporal_det.py b/demo/demo_spatiotemporal_det.py new file mode 100644 index 0000000000..a8e49c1020 --- /dev/null +++ b/demo/demo_spatiotemporal_det.py @@ -0,0 +1,402 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import copy as cp +import os +import os.path as osp +import shutil + +import cv2 +import mmcv +import mmengine +import numpy as np +import torch +from mmengine import DictAction +from mmengine.runner import load_checkpoint +from mmengine.structures import InstanceData + +from mmaction.apis import detection_inference +from mmaction.registry import MODELS +from mmaction.structures import ActionDataSample +from mmaction.utils import register_all_modules + +try: + import moviepy.editor as mpy +except ImportError: + raise ImportError('Please install moviepy to enable output file') + +FONTFACE = cv2.FONT_HERSHEY_DUPLEX +FONTSCALE = 0.5 +FONTCOLOR = (255, 255, 255) # BGR, white +MSGCOLOR = (128, 128, 128) # BGR, gray +THICKNESS = 1 +LINETYPE = 1 + + +def hex2color(h): + """Convert the 6-digit hex string to tuple of 3 int value (RGB)""" + return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16)) + + +plate_blue = '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4' +plate_blue = plate_blue.split('-') +plate_blue = [hex2color(h) for h in plate_blue] +plate_green = '004b23-006400-007200-008000-38b000-70e000' +plate_green = plate_green.split('-') +plate_green = [hex2color(h) for h in plate_green] + + +def visualize(frames, annotations, plate=plate_blue, max_num=5): + """Visualize frames with predicted annotations. + + Args: + frames (list[np.ndarray]): Frames for visualization, note that + len(frames) % len(annotations) should be 0. + annotations (list[list[tuple]]): The predicted results. + plate (str): The plate used for visualization. Default: plate_blue. 
+ max_num (int): Max number of labels to visualize for a person box. + Default: 5. + Returns: + list[np.ndarray]: Visualized frames. + """ + + assert max_num + 1 <= len(plate) + plate = [x[::-1] for x in plate] + frames_out = cp.deepcopy(frames) + nf, na = len(frames), len(annotations) + assert nf % na == 0 + nfpa = len(frames) // len(annotations) + anno = None + h, w, _ = frames[0].shape + scale_ratio = np.array([w, h, w, h]) + for i in range(na): + anno = annotations[i] + if anno is None: + continue + for j in range(nfpa): + ind = i * nfpa + j + frame = frames_out[ind] + for ann in anno: + box = ann[0] + label = ann[1] + if not len(label): + continue + score = ann[2] + box = (box * scale_ratio).astype(np.int64) + st, ed = tuple(box[:2]), tuple(box[2:]) + cv2.rectangle(frame, st, ed, plate[0], 2) + for k, lb in enumerate(label): + if k >= max_num: + break + text = abbrev(lb) + text = ': '.join([text, str(score[k])]) + location = (0 + st[0], 18 + k * 18 + st[1]) + textsize = cv2.getTextSize(text, FONTFACE, FONTSCALE, + THICKNESS)[0] + textwidth = textsize[0] + diag0 = (location[0] + textwidth, location[1] - 14) + diag1 = (location[0], location[1] + 2) + cv2.rectangle(frame, diag0, diag1, plate[k + 1], -1) + cv2.putText(frame, text, location, FONTFACE, FONTSCALE, + FONTCOLOR, THICKNESS, LINETYPE) + + return frames_out + + +def frame_extraction(video_path): + """Extract frames given video_path. + + Args: + video_path (str): The video_path. + """ + # Load the video, extract frames into ./tmp/video_name + target_dir = osp.join('./tmp', osp.basename(osp.splitext(video_path)[0])) + os.makedirs(target_dir, exist_ok=True) + # Should be able to handle videos up to several hours + frame_tmpl = osp.join(target_dir, 'img_{:06d}.jpg') + vid = cv2.VideoCapture(video_path) + frames = [] + frame_paths = [] + flag, frame = vid.read() + cnt = 0 + while flag: + frames.append(frame) + frame_path = frame_tmpl.format(cnt + 1) + frame_paths.append(frame_path) + cv2.imwrite(frame_path, frame) + cnt += 1 + flag, frame = vid.read() + return frame_paths, frames + + +def load_label_map(file_path): + """Load Label Map. + + Args: + file_path (str): The file path of label map. + Returns: + dict: The label map (int -> label name). + """ + lines = open(file_path).readlines() + lines = [x.strip().split(': ') for x in lines] + return {int(x[0]): x[1] for x in lines} + + +def abbrev(name): + """Get the abbreviation of label name: + + 'take (an object) from (a person)' -> 'take ... from ...' + """ + while name.find('(') != -1: + st, ed = name.find('('), name.find(')') + name = name[:st] + '...' + name[ed + 1:] + return name + + +def pack_result(human_detection, result, img_h, img_w): + """Short summary. + + Args: + human_detection (np.ndarray): Human detection result. + result (type): The predicted label of each human proposal. + img_h (int): The image height. + img_w (int): The image width. + Returns: + tuple: Tuple of human proposal, label name and label score. 
+ """ + human_detection[:, 0::2] /= img_w + human_detection[:, 1::2] /= img_h + results = [] + if result is None: + return None + for prop, res in zip(human_detection, result): + res.sort(key=lambda x: -x[1]) + results.append( + (prop.data.cpu().numpy(), [x[0] for x in res], [x[1] + for x in res])) + return results + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMAction2 demo') + parser.add_argument('video', help='video file/url') + parser.add_argument( + 'out_filename', help='output filename', default='demo/stdet_demo.mp4') + parser.add_argument( + '--config', + default=('configs/detection/ava/slowonly_kinetics400-pretrained-' + 'r101_8xb16-8x8x1-20e_ava21-rgb.py'), + help='spatialtemporal detection model config file path') + parser.add_argument( + '--checkpoint', + default=('https://download.openmmlab.com/mmaction/detection/ava/' + 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/' + 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_' + '20201217-16378594.pth'), + help='spatialtemporal detection model checkpoint file/url') + parser.add_argument( + '--det-config', + default='demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py', + help='human detection config file path (from mmdet)') + parser.add_argument( + '--det-checkpoint', + default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/' + 'faster_rcnn_r50_fpn_2x_coco/' + 'faster_rcnn_r50_fpn_2x_coco_' + 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'), + help='human detection checkpoint file/url') + parser.add_argument( + '--det-score-thr', + type=float, + default=0.9, + help='the threshold of human detection score') + parser.add_argument( + '--det-cat-id', + type=int, + default=0, + help='the category id for human detection') + parser.add_argument( + '--action-score-thr', + type=float, + default=0.5, + help='the threshold of human action score') + parser.add_argument( + '--label-map', + default='tools/data/ava/label_map.txt', + help='label map file') + parser.add_argument( + '--device', type=str, default='cuda:0', help='CPU/CUDA device option') + parser.add_argument( + '--short-side', + type=int, + default=256, + help='specify the short-side length of the image') + parser.add_argument( + '--predict-stepsize', + default=8, + type=int, + help='give out a prediction per n frames') + parser.add_argument( + '--output-stepsize', + default=4, + type=int, + help=('show one frame per n frames in the demo, we should have: ' + 'predict_stepsize % output_stepsize == 0')) + parser.add_argument( + '--output-fps', + default=6, + type=int, + help='the fps of demo video output') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. 
For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + register_all_modules() + + frame_paths, original_frames = frame_extraction(args.video) + num_frame = len(frame_paths) + h, w, _ = original_frames[0].shape + + # resize frames to shortside + new_w, new_h = mmcv.rescale_size((w, h), (args.short_side, np.Inf)) + frames = [mmcv.imresize(img, (new_w, new_h)) for img in original_frames] + w_ratio, h_ratio = new_w / w, new_h / h + + # Get clip_len, frame_interval and calculate center index of each clip + config = mmengine.Config.fromfile(args.config) + config.merge_from_dict(args.cfg_options) + val_pipeline = config.val_pipeline + + sampler = [x for x in val_pipeline if x['type'] == 'SampleAVAFrames'][0] + clip_len, frame_interval = sampler['clip_len'], sampler['frame_interval'] + window_size = clip_len * frame_interval + assert clip_len % 2 == 0, 'We would like to have an even clip_len' + # Note that it's 1 based here + timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2, + args.predict_stepsize) + + # Load label_map + label_map = load_label_map(args.label_map) + try: + if config['data']['train']['custom_classes'] is not None: + label_map = { + id + 1: label_map[cls] + for id, cls in enumerate(config['data']['train'] + ['custom_classes']) + } + except KeyError: + pass + + # Get Human detection results + center_frames = [frame_paths[ind - 1] for ind in timestamps] + + human_detections, _ = detection_inference(args.det_config, + args.det_checkpoint, + center_frames, + args.det_score_thr, + args.det_cat_id, args.device) + torch.cuda.empty_cache() + for i in range(len(human_detections)): + det = human_detections[i] + det[:, 0:4:2] *= w_ratio + det[:, 1:4:2] *= h_ratio + human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device) + + # Build STDET model + try: + # In our spatiotemporal detection demo, different actions should have + # the same number of bboxes. 
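+        # `action_thr=0` keeps the scores of every action class for every
+        # human proposal; per-class filtering is done later with
+        # `--action-score-thr`.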
+ config['model']['test_cfg']['rcnn'] = dict(action_thr=0) + except KeyError: + pass + + config.model.backbone.pretrained = None + model = MODELS.build(config.model) + + load_checkpoint(model, args.checkpoint, map_location='cpu') + model.to(args.device) + model.eval() + + predictions = [] + + img_norm_cfg = dict( + mean=np.array(config.model.data_preprocessor.mean), + std=np.array(config.model.data_preprocessor.std), + to_rgb=False) + + print('Performing SpatioTemporal Action Detection for each clip') + assert len(timestamps) == len(human_detections) + prog_bar = mmengine.ProgressBar(len(timestamps)) + for timestamp, proposal in zip(timestamps, human_detections): + if proposal.shape[0] == 0: + predictions.append(None) + continue + + start_frame = timestamp - (clip_len // 2 - 1) * frame_interval + frame_inds = start_frame + np.arange(0, window_size, frame_interval) + frame_inds = list(frame_inds - 1) + imgs = [frames[ind].astype(np.float32) for ind in frame_inds] + _ = [mmcv.imnormalize_(img, **img_norm_cfg) for img in imgs] + # THWC -> CTHW -> 1CTHW + input_array = np.stack(imgs).transpose((3, 0, 1, 2))[np.newaxis] + input_tensor = torch.from_numpy(input_array).to(args.device) + + datasample = ActionDataSample() + datasample.proposals = InstanceData(bboxes=proposal) + datasample.set_metainfo(dict(img_shape=(new_h, new_w))) + with torch.no_grad(): + result = model(input_tensor, [datasample], mode='predict') + scores = result[0].pred_instances.scores + prediction = [] + # N proposals + for i in range(proposal.shape[0]): + prediction.append([]) + # Perform action score thr + for i in range(scores.shape[1]): + if i not in label_map: + continue + for j in range(proposal.shape[0]): + if scores[j, i] > args.action_score_thr: + prediction[j].append((label_map[i], scores[j, + i].item())) + predictions.append(prediction) + prog_bar.update() + + results = [] + for human_detection, prediction in zip(human_detections, predictions): + results.append(pack_result(human_detection, prediction, new_h, new_w)) + + def dense_timestamps(timestamps, n): + """Make it nx frames.""" + old_frame_interval = (timestamps[1] - timestamps[0]) + start = timestamps[0] - old_frame_interval / n * (n - 1) / 2 + new_frame_inds = np.arange( + len(timestamps) * n) * old_frame_interval / n + start + return new_frame_inds.astype(np.int) + + dense_n = int(args.predict_stepsize / args.output_stepsize) + frames = [ + cv2.imread(frame_paths[i - 1]) + for i in dense_timestamps(timestamps, dense_n) + ] + print('Performing visualization') + vis_frames = visualize(frames, results) + vid = mpy.ImageSequenceClip([x[:, :, ::-1] for x in vis_frames], + fps=args.output_fps) + vid.write_videofile(args.out_filename) + + tmp_frame_dir = osp.dirname(frame_paths[0]) + shutil.rmtree(tmp_frame_dir) + + +if __name__ == '__main__': + main() diff --git a/docs/en/user_guides/3_inference.md b/docs/en/user_guides/3_inference.md index 8a603b66ee..06d6f55537 100644 --- a/docs/en/user_guides/3_inference.md +++ b/docs/en/user_guides/3_inference.md @@ -117,3 +117,62 @@ python demo/demo_skeleton.py demo/demo_skeleton.mp4 demo/demo_skeleton_out.mp4 \ --pose-checkpoint https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth \ --label-map tools/data/skeleton/label_map_ntu60.txt ``` + +## SpatioTemporal Action Detection Video Demo + +We provide a demo script to predict the SpatioTemporal Action Detection result using a single video. 
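+
+Under the hood, the demo makes one prediction per sampled timestamp: it reads `clip_len` and `frame_interval` from the `SampleAVAFrames` step of the config's `val_pipeline`, samples `clip_len` frames spaced `frame_interval` apart from a window centered on each timestamp, and moves the window forward by `PREDICT_STEPSIZE` frames. A minimal sketch of this sampling arithmetic (the concrete numbers below are illustrative assumptions, not fixed values):
+
+```python
+import numpy as np
+
+num_frame = 217                  # frames extracted from the input video (assumed)
+clip_len, frame_interval = 8, 8  # read from SampleAVAFrames in the real demo
+predict_stepsize = 8             # --predict-stepsize
+
+window_size = clip_len * frame_interval
+# 1-based center frames at which an action prediction is made
+timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
+                       predict_stepsize)
+# frame indices fed to the model for one timestamp
+start_frame = timestamps[0] - (clip_len // 2 - 1) * frame_interval
+frame_inds = start_frame + np.arange(0, window_size, frame_interval)
+```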
+
+```shell
+python demo/demo_spatiotemporal_det.py ${VIDEO_FILE} ${OUTPUT_FILENAME} \
+    [--config ${SPATIOTEMPORAL_ACTION_DETECTION_CONFIG_FILE}] \
+    [--checkpoint ${SPATIOTEMPORAL_ACTION_DETECTION_CHECKPOINT}] \
+    [--det-config ${HUMAN_DETECTION_CONFIG_FILE}] \
+    [--det-checkpoint ${HUMAN_DETECTION_CHECKPOINT}] \
+    [--det-score-thr ${HUMAN_DETECTION_SCORE_THRESHOLD}] \
+    [--det-cat-id ${HUMAN_DETECTION_CATEGORY_ID}] \
+    [--action-score-thr ${ACTION_DETECTION_SCORE_THRESHOLD}] \
+    [--label-map ${LABEL_MAP}] \
+    [--device ${DEVICE}] \
+    [--short-side ${SHORT_SIDE}] \
+    [--predict-stepsize ${PREDICT_STEPSIZE}] \
+    [--output-stepsize ${OUTPUT_STEPSIZE}] \
+    [--output-fps ${OUTPUT_FPS}]
+```
+
+Optional arguments:
+
+- `OUTPUT_FILENAME`: Path to the output video file. Defaults to `demo/stdet_demo.mp4`.
+- `SPATIOTEMPORAL_ACTION_DETECTION_CONFIG_FILE`: The spatiotemporal action detection config file path.
+- `SPATIOTEMPORAL_ACTION_DETECTION_CHECKPOINT`: The spatiotemporal action detection checkpoint URL.
+- `HUMAN_DETECTION_CONFIG_FILE`: The human detection config file path.
+- `HUMAN_DETECTION_CHECKPOINT`: The human detection checkpoint URL.
+- `HUMAN_DETECTION_SCORE_THRESHOLD`: The score threshold for human detection. Defaults to 0.9.
+- `HUMAN_DETECTION_CATEGORY_ID`: The category id for human detection. Defaults to 0.
+- `ACTION_DETECTION_SCORE_THRESHOLD`: The score threshold for action detection. Defaults to 0.5.
+- `LABEL_MAP`: The label map used. Defaults to `tools/data/ava/label_map.txt`.
+- `DEVICE`: The device to run the demo on. Allowed values are a CUDA device like `cuda:0` or `cpu`. Defaults to `cuda:0`.
+- `SHORT_SIDE`: The short side used for frame extraction. Defaults to 256.
+- `PREDICT_STEPSIZE`: Make a prediction every N frames. Defaults to 8.
+- `OUTPUT_STEPSIZE`: Output 1 frame for every N frames of the input video. Note that `PREDICT_STEPSIZE % OUTPUT_STEPSIZE == 0` must hold. Defaults to 4.
+- `OUTPUT_FPS`: The FPS of the output demo video. Defaults to 6.
+
+Examples:
+
+Assume that you are located at `$MMACTION2`.
+
+1. Use Faster RCNN as the human detector and SlowOnly-8x8-R101 as the action detector. Make a prediction every 8 frames and keep 1 of every 4 frames for the output video, which is written at 6 FPS.
+
+```shell
+python demo/demo_spatiotemporal_det.py demo/demo.mp4 demo/demo_spatiotemporal_det.mp4 \
+    --config configs/detection/ava/slowonly_kinetics400-pretrained-r101_8xb16-8x8x1-20e_ava21-rgb.py \
+    --checkpoint https://download.openmmlab.com/mmaction/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb_20201217-16378594.pth \
+    --det-config demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py \
+    --det-checkpoint http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth \
+    --det-score-thr 0.9 \
+    --action-score-thr 0.5 \
+    --label-map tools/data/ava/label_map.txt \
+    --predict-stepsize 8 \
+    --output-stepsize 4 \
+    --output-fps 6
+```
diff --git a/mmaction/apis/inference.py b/mmaction/apis/inference.py
index 0e3ab2ddac..67d7d6bce2 100644
--- a/mmaction/apis/inference.py
+++ b/mmaction/apis/inference.py
@@ -127,7 +127,8 @@ def detection_inference(det_config: Union[str, Path, mmengine.Config],
                           '`init_detector` from `mmdet.apis`. These apis are '
                           'required in this inference api! ')
 
-    model = init_detector(det_config, det_checkpoint, device)
+    model = init_detector(
+        config=det_config, checkpoint=det_checkpoint, device=device)
 
     results = []
     data_samples = []
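
For reference, the `detection_inference` helper touched above is what the new demo calls to get per-frame human boxes. A minimal sketch of that call, mirroring `demo_spatiotemporal_det.py` (the frame paths and device are assumptions; the detector config and checkpoint are the demo's defaults):

```python
import torch

from mmaction.apis import detection_inference
from mmaction.utils import register_all_modules

register_all_modules()

# Assumed inputs: frame images written by frame_extraction(); the exact
# filenames here are hypothetical.
center_frames = ['tmp/demo/img_000032.jpg', 'tmp/demo/img_000040.jpg']

det_config = 'demo/skeleton_demo_cfg/faster-rcnn_r50_fpn_2x_coco_infer.py'
det_checkpoint = (
    'http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
    'faster_rcnn_r50_fpn_2x_coco/'
    'faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth')

# Same positional call as in the demo: config, checkpoint, frame paths,
# detection score threshold, person category id, device.
human_detections, _ = detection_inference(det_config, det_checkpoint,
                                          center_frames, 0.9, 0, 'cuda:0')

# One array of person boxes per requested frame.
for det in human_detections:
    print(det[:, :4])
torch.cuda.empty_cache()
```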