mirror of
https://github.com/TelegramMessenger/tgcalls.git
synced 2026-05-21 18:20:42 +00:00
feat: add testbench (CLI tool, Go SFU, Dockerfile, docs)
Brings the testbench source into the submodule: - tools/cli/ — C++ CLI test tool (P2P, reflector, group, group-churn modes) - tools/go_sfu/ — Go/Pion SFU library, c-archive linked into tgcalls_cli - Dockerfile — multi-stage Linux container build - CLAUDE.md (top-level), tools/cli/CLAUDE.md, tools/go_sfu/CLAUDE.md — docs Bazel build glue (.bazelrc, MODULE.bazel, third-party BUILD edits, tgcalls_core target) remains in the outer repo since the dependency stack lives there; labels and paths in this repo reference the outer-repo workspace root. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This is a testbench repository for the tgcalls VoIP library (from Telegram). It contains the full Telegram iOS source tree as a build dependency, but the focus is on testing and debugging tgcalls.
|
||||
|
||||
## Build
|
||||
|
||||
Requires Bazel 8.4.2 (download to `build-input/` if not present):
|
||||
|
||||
```bash
|
||||
# One-time setup: create build configuration stub
|
||||
mkdir -p build-input/configuration-repository/provisioning
|
||||
# Then populate MODULE.bazel, BUILD, variables.bzl, provisioning/BUILD
|
||||
# (see build-input/configuration-repository/ for existing stubs)
|
||||
|
||||
# Build the CLI test tool
|
||||
./build-input/bazel-8.4.2 build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli
|
||||
```
|
||||
|
||||
The system-installed Bazel (v9) is NOT compatible with this codebase.
|
||||
|
||||
## Linux Build
|
||||
|
||||
Prerequisites (Ubuntu/Debian):
|
||||
```bash
|
||||
apt install gcc g++ cmake meson ninja-build nasm make autoconf automake libtool pkg-config zlib1g-dev libbz2-dev
|
||||
```
|
||||
|
||||
Download the Linux Bazel 8.4.2 binary to `build-input/`:
|
||||
```bash
|
||||
curl -fL "https://github.com/bazelbuild/bazel/releases/download/8.4.2/bazel-8.4.2-linux-arm64" -o build-input/bazel-8.4.2-linux
|
||||
chmod +x build-input/bazel-8.4.2-linux
|
||||
```
|
||||
|
||||
Build the CLI test tool:
|
||||
```bash
|
||||
./build-input/bazel-8.4.2-linux build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli
|
||||
```
|
||||
|
||||
The same Bazel 8.4.2 version is required. The build uses the system GCC toolchain and system-installed cmake/meson/ninja for third-party library compilation.
|
||||
|
||||
## Docker Build
|
||||
|
||||
Build a minimal Linux container image from macOS (or any Docker host):
|
||||
|
||||
```bash
|
||||
# Build (uses BuildKit cache — first build ~5 min, rebuilds seconds)
|
||||
docker build -t tgcalls-test .
|
||||
|
||||
# Run locally
|
||||
docker run --rm tgcalls-test --mode p2p --duration 5 --quiet
|
||||
docker run --rm tgcalls-test --mode reflector --reflector 91.108.13.2:598 --duration 10 --quiet
|
||||
|
||||
# Push to ECR for AWS deployment
|
||||
docker tag tgcalls-test 654654616143.dkr.ecr.eu-west-1.amazonaws.com/tgcalls-test:latest
|
||||
docker push 654654616143.dkr.ecr.eu-west-1.amazonaws.com/tgcalls-test:latest
|
||||
```
|
||||
|
||||
The Dockerfile uses a multi-stage build: full build environment in stage 1, minimal runtime image (~50MB) in stage 2. Bazel's build cache is preserved across `docker build` invocations via `--mount=type=cache`. The image is built for ARM64 (matches Apple Silicon and Fargate ARM).
|
||||
|
||||
## Testing
|
||||
|
||||
### Local Mass Testing
|
||||
|
||||
Run large-scale P2P tests locally using `run-local-test.sh`. Launches N parallel processes, each running a single call, and aggregates results.
|
||||
|
||||
```bash
|
||||
# 1000 calls, 150 parallel, 30% loss (default settings)
|
||||
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 1000
|
||||
|
||||
# Custom parallelism and duration
|
||||
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 500 -j 100 -d 30
|
||||
|
||||
# Custom loss parameters
|
||||
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 1000 --drop-rate 0.5 --delay 100-300
|
||||
```
|
||||
|
||||
Options: `-n NUM` (count), `-j PARALLEL` (default 150), `-d DURATION` (default 15s), `--drop-rate RATE` (default 0.3), `--delay MIN-MAX` (default 50-200), `--mode MODE` (default p2p), `--version VER` (default 13.0.0).
|
||||
|
||||
Typical results: 100% success rate at 30% loss on Apple Silicon (16 cores).
|
||||
|
||||
### AWS Mass Testing
|
||||
|
||||
Run large-scale reflector tests on ECS Fargate (ARM64). Infrastructure is pre-configured in eu-west-1. Requires Docker push first.
|
||||
|
||||
```bash
|
||||
# Launch 1000 tasks across all Telegram reflectors, 30s each
|
||||
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-test.sh -n 1000 -d 30
|
||||
|
||||
# Collect results
|
||||
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-test.sh --results
|
||||
```
|
||||
|
||||
The script fetches the reflector list from `https://core.telegram.org/getReflectorList`, embeds the IPs as a `--reflector-list` argument (each task picks a random IP + random port 596-599), and launches in waves of 500 (Fargate concurrent task limit). Results are collected from CloudWatch Logs with automatic retry for delayed log delivery.
|
||||
|
||||
**AWS resources** (eu-west-1, account 654654616143):
|
||||
- ECR: `tgcalls-test`
|
||||
- ECS cluster: `tgcalls-test`
|
||||
- Task definition: `tgcalls-test` (ARM64 Fargate, 0.25 vCPU, 512MB)
|
||||
- CloudWatch log group: `/ecs/tgcalls-test`
|
||||
- Subnets: `subnet-0292f49f3b4885428`, `subnet-09b8edab6eb20b837`, `subnet-0f464b5c62c9a6d1a`
|
||||
- Security group: `sg-0d87a1f19be76c160`
|
||||
|
||||
**Cost**: ~$0.01 per 100 tasks (~$0.10 per 1000-task run).
|
||||
|
||||
## tgcalls CLI Test Tool
|
||||
|
||||
Located at `submodules/TgVoipWebrtc/tgcalls/tools/cli/`. Runs tgcalls instances in-process with emulated signaling and validates audio/media flow.
|
||||
|
||||
```bash
|
||||
# P2P mode (direct loopback, no network)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 10
|
||||
|
||||
# Reflector mode (routes through a real Telegram UDP reflector)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode reflector --reflector 91.108.13.2:596 --duration 10
|
||||
|
||||
# Random reflector from a list (picks one at random, randomizes port 596-599 if no port given)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --reflector-list "91.108.13.2,91.108.13.3,91.108.9.1" --duration 10
|
||||
|
||||
# Simulate lossy signaling (30% drop, 50-200ms random delay)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 30 --drop-rate 0.3 --delay 50-200
|
||||
|
||||
# Quiet mode (summary only, full tgcalls logs dumped on failure)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 5 --quiet
|
||||
|
||||
# Group mode (in-process SFU with N participants)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 3 --duration 10
|
||||
|
||||
# Mixed group mode (CustomImpl + ReferenceImpl participants)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --reference-participants 2 --duration 15
|
||||
|
||||
# Group mode with video (H264 simulcast, pattern generator)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 15
|
||||
|
||||
# Mixed group with video (both CustomImpl and ReferenceImpl send/receive video)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --reference-participants 2 --video --duration 15
|
||||
|
||||
# ReferenceImpl-only video
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 0 --reference-participants 3 --video --duration 15
|
||||
|
||||
# Group churn stress test (100 join/leave cycles, then validate base group)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 3 --duration 10
|
||||
|
||||
# Group churn with video
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 3 --video --churn-cycles 100 --duration 10
|
||||
|
||||
# Mixed implementations churn
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 2 --reference-participants 1 --video --duration 10
|
||||
```
|
||||
|
||||
`--mode` is required (`p2p`, `reflector`, `group`, or `group-churn`) unless `--reflector-list` is used (implies reflector mode). Exit code 0 = success. Exit code 1 = failure.
|
||||
|
||||
For p2p/reflector: success = call established, stats logs non-empty, BWE non-zero for both sides.
|
||||
For group (audio): success = all N participants report `isConnected = true` AND all participants receive remote audio (non-zero SSRC with level > 0.05 via `audioLevelsUpdated`). Remote 440Hz sine tone typically arrives at ~0.126 level.
|
||||
For group (video): audio criteria plus every participant receives ≥1 decoded video frame from every other participant via `FakeVideoSink` frame counting.
|
||||
For group-churn: success = all churn cycles complete without crash/hang AND base group passes group validation (all connected, all receiving audio, and if video, all receiving video from all other base participants).
|
||||
|
||||
### CLI Options
|
||||
- `--mode p2p|reflector|group|group-churn` — call mode (required unless `--reflector-list` used)
|
||||
- `--reflector host:port` — single reflector address
|
||||
- `--reflector-list addr,addr,...` — comma-separated list, one picked at random
|
||||
- `--version VER` — caller tgcalls protocol version (default: `13.0.0`)
|
||||
- `--version2 VER` — callee tgcalls protocol version (default: same as `--version`). Enables cross-version interop testing.
|
||||
- `--participants N` — number of CustomImpl participants in group mode (default: 3)
|
||||
- `--reference-participants N` — number of ReferenceImpl (PeerConnection-based) participants in group mode (default: 0). Total = `--participants` + `--reference-participants`.
|
||||
- `--duration N` — test duration in seconds (default: 10)
|
||||
- `--drop-rate 0.0-1.0` — signaling packet drop probability
|
||||
- `--delay min-max` — signaling delay range in ms (e.g., `50-200`)
|
||||
- `--video` — enable H264 video with simulcast in group mode (both CustomImpl and ReferenceImpl participants)
|
||||
- `--churn-cycles N` — number of join/leave cycles in group-churn mode (default: 100)
|
||||
- `--network-scenario NAME` — network simulation test scenario (e.g., `step-down-up`). Group mode only.
|
||||
- `--quiet` — summary output only
|
||||
|
||||
### Modes
|
||||
- **P2P**: Direct loopback, `enableP2P=true`, no servers configured
|
||||
- **Reflector**: Routes through a Telegram UDP reflector, `enableP2P=false`, configures `RtcServer` with `login="reflector"` and random peer tags (16 bytes, byte 0 = `0x00` for caller, `0x01` for callee)
|
||||
- **Group**: In-process SFU with N participants using `GroupInstanceCustomImpl` and/or `GroupInstanceReferenceImpl`. The SFU is implemented in Go using Pion's low-level ICE/DTLS/SRTP/SCTP APIs (not PeerConnection), linked into the same process via CGo c-archive. Each participant gets a full ICE + DTLS + SRTP + SCTP transport stack over localhost UDP. Audio RTP is selectively forwarded between all participants. With `--video`, H264 video with 3-layer simulcast is enabled. Mixed-implementation groups (CustomImpl + ReferenceImpl) are supported via `--reference-participants`.
|
||||
- **Group Churn**: Stress test for participant join/leave dynamics. Creates a base group of N participants, then rapidly cycles an additional participant in and out `--churn-cycles` times (default 100). After churn, validates that the base group is healthy: all connected, all receiving audio, and if `--video` is enabled, all receiving video. Alternates between CustomImpl and ReferenceImpl for the cycling participant. The `--duration` controls the stabilization wait after churn completes.
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/` — CLI test tool (main.cpp, group_mode.cpp, group_participant.h/.cpp, group_churn_mode.h/.cpp, fake_video_source.h/.cpp, fake_video_sink.h, run-test.sh, run-local-test.sh, BUILD)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/` — Go/Pion SFU library (sfu.go, participant.go, mux.go, go.mod/go.sum), built as c-archive via rules_go + Gazelle, linked into tgcalls_cli
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/` — tgcalls library source
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/group/` — group call implementations (GroupInstanceCustomImpl, GroupInstanceReferenceImpl, GroupNetworkManager, GroupJoinPayloadInternal)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/v2/` — v2 implementation (InstanceV2Impl, InstanceV2ReferenceImpl, InstanceV2CompatImpl, NativeNetworkingImpl, SignalingSctpConnection, SignalingTranslator)
|
||||
- `submodules/TgVoipWebrtc/BUILD` — contains `tgcalls_core` target (C++ only, macOS-native) and `TgVoipWebrtc` target (iOS, ObjC)
|
||||
- `third-party/webrtc/` — WebRTC source and BUILD
|
||||
- `third-party/webrtc/webrtc/net/dcsctp/` — dc-sctp (SCTP implementation)
|
||||
- `third-party/webrtc/webrtc/media/sctp/dcsctp_transport.cc` — WebRTC SCTP wrapper
|
||||
- `third-party/` — other dependencies (opus, libvpx, ffmpeg, boringssl, etc.)
|
||||
|
||||
## Code Style
|
||||
- **Naming**: PascalCase for types, camelCase for variables/methods
|
||||
- **Language**: C++17 for tgcalls code
|
||||
- **Formatting**: Standard C++ formatting
|
||||
|
||||
## Further Context
|
||||
|
||||
When working in these areas, additional `CLAUDE.md` files load automatically:
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/CLAUDE.md` — CLI test tool architecture (P2P/Reflector, Group), supported version matrix
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/CLAUDE.md` — Go SFU internals: build integration, bandwidth adaptation, transport-cc feedback, network simulation
|
||||
- `submodules/TgVoipWebrtc/CLAUDE.md` — tgcalls library internals: macOS/Linux build patches, SCTP signaling, InstanceV2CompatImpl, GroupInstanceCustomImpl/ReferenceImpl, video pitfalls, known issues
|
||||
+52
@@ -0,0 +1,52 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
# Multi-stage build for tgcalls_cli Linux container
|
||||
# Build: docker build -t tgcalls-test .
|
||||
# Run: docker run tgcalls-test --mode reflector --reflector 91.108.13.2:598 --duration 10
|
||||
|
||||
# ============================================================
|
||||
# Stage 1: Build
|
||||
# ============================================================
|
||||
FROM ubuntu:24.04 AS builder
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc g++ cmake meson ninja-build nasm make \
|
||||
autoconf automake libtool pkg-config python3 \
|
||||
unzip curl ca-certificates patch \
|
||||
zlib1g-dev libbz2-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Copy source tree
|
||||
COPY . .
|
||||
|
||||
# Always download Bazel for the container's architecture (host copy may be wrong arch)
|
||||
RUN ARCH=$(uname -m) && \
|
||||
if [ "$ARCH" = "x86_64" ]; then BAZEL_ARCH="x86_64"; \
|
||||
elif [ "$ARCH" = "aarch64" ]; then BAZEL_ARCH="arm64"; \
|
||||
else echo "Unsupported arch: $ARCH" && exit 1; fi && \
|
||||
curl -fL "https://github.com/bazelbuild/bazel/releases/download/8.4.2/bazel-8.4.2-linux-${BAZEL_ARCH}" \
|
||||
-o build-input/bazel-8.4.2-linux && \
|
||||
chmod +x build-input/bazel-8.4.2-linux
|
||||
|
||||
# Build with persistent Bazel cache
|
||||
RUN --mount=type=cache,target=/root/.cache/bazel \
|
||||
./build-input/bazel-8.4.2-linux build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli \
|
||||
--strategy=Genrule=standalone --spawn_strategy=standalone && \
|
||||
cp bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli /tmp/tgcalls_cli
|
||||
|
||||
# ============================================================
|
||||
# Stage 2: Runtime (minimal)
|
||||
# ============================================================
|
||||
FROM ubuntu:24.04
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /tmp/tgcalls_cli /usr/local/bin/tgcalls_cli
|
||||
|
||||
ENTRYPOINT ["tgcalls_cli"]
|
||||
CMD ["--help"]
|
||||
@@ -0,0 +1,37 @@
|
||||
cc_binary(
|
||||
name = "tgcalls_cli",
|
||||
srcs = [
|
||||
"main.cpp",
|
||||
"group_mode.cpp",
|
||||
"group_mode.h",
|
||||
"group_participant.cpp",
|
||||
"group_participant.h",
|
||||
"group_churn_mode.cpp",
|
||||
"group_churn_mode.h",
|
||||
"fake_video_source.h",
|
||||
"fake_video_source.cpp",
|
||||
"fake_video_sink.h",
|
||||
],
|
||||
copts = [
|
||||
"-I{}/tgcalls/tgcalls".format("submodules/TgVoipWebrtc"),
|
||||
"-Ithird-party/webrtc/webrtc",
|
||||
"-Ithird-party/webrtc/dependencies",
|
||||
"-Ithird-party/webrtc/absl",
|
||||
"-Ithird-party/libyuv",
|
||||
"-DRTC_ENABLE_VP9",
|
||||
"-DNDEBUG",
|
||||
"-std=c++17",
|
||||
"-w",
|
||||
] + select({
|
||||
"@platforms//os:linux": ["-DWEBRTC_LINUX", "-DWEBRTC_POSIX"],
|
||||
"//conditions:default": ["-DWEBRTC_MAC", "-DWEBRTC_POSIX"],
|
||||
}),
|
||||
linkopts = select({
|
||||
"@platforms//os:linux": ["-lpthread", "-lm", "-ldl"],
|
||||
"//conditions:default": ["-framework", "CoreFoundation", "-framework", "Security"],
|
||||
}),
|
||||
deps = [
|
||||
"//submodules/TgVoipWebrtc:tgcalls_core",
|
||||
"//submodules/TgVoipWebrtc/tgcalls/tools/go_sfu",
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
# tgcalls CLI Test Tool
|
||||
|
||||
In-process test harness for tgcalls. See the root `CLAUDE.md` for build instructions, top-level CLI usage, and the CLI options reference.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
| Version | Implementation | Notes |
|
||||
|---|---|---|
|
||||
| `14.0.0` | `InstanceV2CompatImpl` | WebRTC PeerConnection + V2Impl signaling. Cross-version interop with 7.0.0–13.0.0 |
|
||||
| `13.0.0` (default) | `InstanceV2Impl` | Also: 7.0.0, 8.0.0, 9.0.0, 12.0.0 |
|
||||
| `11.0.0` | `InstanceV2ReferenceImpl` | Also: 10.0.0. Uses WebRTC PeerConnection |
|
||||
| `5.0.0` | `InstanceImpl` (v1) | Also: 2.7.7. Legacy |
|
||||
|
||||
## Architecture (P2P/Reflector)
|
||||
- Two `tgcalls::Instance` objects (caller + callee) created via `Meta::Create(version, ...)`
|
||||
- Signaling bridged via `SignalingBridge` with configurable drop rate and delay
|
||||
- `FakeAudioDeviceModule` with `SineRecorder` (440Hz tone) and `NoOpRenderer` (audio discarded; validation via BWE)
|
||||
- `FakeInterface` platform implementation (pure C++, no iOS/ObjC deps)
|
||||
- Stats log validation: both caller and callee write `config.statsLogPath` with bitrate records; non-empty log with at least one non-zero BWE value is a success condition
|
||||
- On failure, full tgcalls internal logs (caller + callee) are dumped to stdout via `config.logPath`
|
||||
|
||||
## Architecture (Group)
|
||||
- N participants using `GroupInstanceCustomImpl` and/or `GroupInstanceReferenceImpl` connect to an in-process Go SFU
|
||||
- SFU uses Pion's low-level APIs (pion/ice, pion/dtls, pion/srtp, pion/sctp) — NOT PeerConnection
|
||||
- ICE: lite mode, loopback-only, UDP host candidates on 127.0.0.1. SFU uses `Dial` (controlling) for CustomImpl clients and `Accept` (controlled) for PeerConnection clients
|
||||
- DTLS: SFU acts as DTLS client (setup=active); GroupNetworkManager hardcodes SSL_SERVER for the tgcalls client
|
||||
- SRTP: AES-256-GCM (negotiated via DTLS-SRTP; GroupNetworkManager requires GCM suites)
|
||||
- SCTP: over DTLS, accepts data channel from client, reads Colibri messages, sends `ActiveAudioSsrcs` and `ActiveVideoSsrcs` notifications
|
||||
- RTP forwarding: audio RTP forwarded to all others unconditionally; video RTP forwarded only to receivers that have requested video from that sender (via `ReceiverVideoConstraints`)
|
||||
- SSRC tracking: SFU maintains `ssrcRegistry map[uint32]ssrcInfo` with kind (audio/video/video-rtx) and simulcast layer index, exposed via `GoSfu_QuerySsrc` and `GoSfu_QueryVideoSsrcs` CGo exports
|
||||
- SSRC discovery: SFU broadcasts `ActiveAudioSsrcs` and `ActiveVideoSsrcs` over data channel when participants connect
|
||||
- Video SSRC groups: parsed from join payload `"ssrc-groups"` field (SIM + FID semantics), stored per participant
|
||||
- Colibri video constraints: SFU parses `ReceiverVideoConstraints` from receivers, sends `SenderVideoConstraints` back to senders with `idealHeight`, and sends proactive PLI to trigger keyframes when a receiver first requests video
|
||||
- RTCP feedback: SFU demuxes SRTCP from the shared ICE transport (RFC 5761: byte[1] >= 200 && < 224), decrypts with per-participant SRTCP contexts, parses PLI/FIR, and forwards as new PLI to the sender. NACK is terminated (not forwarded).
|
||||
- Audio validation: `audioLevelsUpdated` callback tracks remote audio levels; success requires every participant to receive audio from at least one other participant (remote SSRC != 0, level > 0.05). The 440Hz sine tone arrives at ~0.126 level after SFU forwarding.
|
||||
- Video validation: `FakeVideoSink` (implements `rtc::VideoSinkInterface<VideoFrame>`) counts decoded frames per remote endpoint; success requires every participant to receive ≥1 frame from every other
|
||||
- Video signaling flow: SFU broadcasts `ActiveVideoSsrcs` over data channel → `dataChannelMessageReceived` callback fires in the app → app calls `setRequestedVideoChannels` → CustomImpl creates `IncomingVideoChannel` / ReferenceImpl adds recvonly video transceiver → both send `ReceiverVideoConstraints` → SFU sends `SenderVideoConstraints` + proactive PLI → sender produces keyframe → receiver decodes
|
||||
- `dataChannelMessageReceived` callback: added to `GroupInstanceDescriptor`, forwards all incoming Colibri data channel messages to the application. Used by the CLI test tool to react to `ActiveVideoSsrcs` and dynamically set up video channels — mirrors the real Telegram app's reactive flow
|
||||
- `FakeAudioDeviceModule` with `SineRecorder` (440Hz tone) and `NoOpRenderer` — same as P2P mode
|
||||
- `FakeVideoTrackSource` generates 1280x720 I420 frames at 30fps with per-participant color tint and frame counter (720p needed for 3 simulcast layers; 640x360 only allows 2 per WebRTC's `kSimulcastFormats`)
|
||||
- Group mode source: `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.cpp`
|
||||
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
#include "api/video/video_frame.h"
|
||||
#include "api/video/video_sink_interface.h"
|
||||
#include <atomic>
|
||||
|
||||
class FakeVideoSink : public rtc::VideoSinkInterface<webrtc::VideoFrame> {
|
||||
public:
|
||||
void OnFrame(const webrtc::VideoFrame& frame) override {
|
||||
frameCount_.fetch_add(1, std::memory_order_relaxed);
|
||||
int w = frame.width();
|
||||
int h = frame.height();
|
||||
lastWidth_.store(w, std::memory_order_relaxed);
|
||||
lastHeight_.store(h, std::memory_order_relaxed);
|
||||
}
|
||||
int lastWidth() const { return lastWidth_.load(std::memory_order_relaxed); }
|
||||
int lastHeight() const { return lastHeight_.load(std::memory_order_relaxed); }
|
||||
int frameCount() const {
|
||||
return frameCount_.load(std::memory_order_relaxed);
|
||||
}
|
||||
private:
|
||||
std::atomic<int> frameCount_{0};
|
||||
std::atomic<int> lastWidth_{0};
|
||||
std::atomic<int> lastHeight_{0};
|
||||
};
|
||||
@@ -0,0 +1,149 @@
|
||||
#include "fake_video_source.h"
|
||||
|
||||
#include "api/video/video_frame.h"
|
||||
#include "api/video/video_rotation.h"
|
||||
#include "rtc_base/time_utils.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr int kWidth = 1280;
|
||||
constexpr int kHeight = 720;
|
||||
constexpr int kFps = 30;
|
||||
constexpr uint8_t kBgY = 80; // dark background
|
||||
constexpr uint8_t kDigitY = 235; // white digits
|
||||
constexpr int kDigitW = 5;
|
||||
constexpr int kDigitH = 7;
|
||||
constexpr int kScale = 4;
|
||||
constexpr int kDigitSpacing = 2; // pixels between digits (scaled)
|
||||
constexpr int kMargin = 8; // top-left margin in pixels
|
||||
|
||||
// 5x7 bitmap font for digits 0-9. Each entry is 7 rows of 5-bit patterns.
|
||||
// MSB = leftmost pixel.
|
||||
static const uint8_t kDigitBitmaps[10][7] = {
|
||||
// 0
|
||||
{0b01110, 0b10001, 0b10011, 0b10101, 0b11001, 0b10001, 0b01110},
|
||||
// 1
|
||||
{0b00100, 0b01100, 0b00100, 0b00100, 0b00100, 0b00100, 0b01110},
|
||||
// 2
|
||||
{0b01110, 0b10001, 0b00001, 0b00010, 0b00100, 0b01000, 0b11111},
|
||||
// 3
|
||||
{0b11111, 0b00010, 0b00100, 0b00010, 0b00001, 0b10001, 0b01110},
|
||||
// 4
|
||||
{0b00010, 0b00110, 0b01010, 0b10010, 0b11111, 0b00010, 0b00010},
|
||||
// 5
|
||||
{0b11111, 0b10000, 0b11110, 0b00001, 0b00001, 0b10001, 0b01110},
|
||||
// 6
|
||||
{0b00110, 0b01000, 0b10000, 0b11110, 0b10001, 0b10001, 0b01110},
|
||||
// 7
|
||||
{0b11111, 0b00001, 0b00010, 0b00100, 0b01000, 0b01000, 0b01000},
|
||||
// 8
|
||||
{0b01110, 0b10001, 0b10001, 0b01110, 0b10001, 0b10001, 0b01110},
|
||||
// 9
|
||||
{0b01110, 0b10001, 0b10001, 0b01111, 0b00001, 0b00010, 0b01100},
|
||||
};
|
||||
|
||||
// 6 color tints cycling: red, green, blue, yellow, cyan, magenta
|
||||
// UV values for each tint (in I420, U=Cb, V=Cr; neutral=128)
|
||||
struct UVTint { uint8_t u; uint8_t v; };
|
||||
static const UVTint kTints[6] = {
|
||||
{90, 240}, // red
|
||||
{54, 34}, // green
|
||||
{240, 110}, // blue
|
||||
{16, 146}, // yellow
|
||||
{166, 16}, // cyan
|
||||
{166, 240}, // magenta
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
FakeVideoTrackSource::FakeVideoTrackSource(int participantId)
|
||||
: participantId_(participantId) {
|
||||
const auto& tint = kTints[participantId % 6];
|
||||
uTint_ = tint.u;
|
||||
vTint_ = tint.v;
|
||||
thread_ = std::thread(&FakeVideoTrackSource::GenerateThread, this);
|
||||
}
|
||||
|
||||
FakeVideoTrackSource::~FakeVideoTrackSource() {
|
||||
Stop();
|
||||
}
|
||||
|
||||
rtc::scoped_refptr<FakeVideoTrackSource> FakeVideoTrackSource::Create(int participantId) {
|
||||
return rtc::scoped_refptr<FakeVideoTrackSource>(
|
||||
new rtc::RefCountedObject<FakeVideoTrackSource>(participantId));
|
||||
}
|
||||
|
||||
void FakeVideoTrackSource::Stop() {
|
||||
if (running_.exchange(false)) {
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FakeVideoTrackSource::GenerateThread() {
|
||||
int frameNumber = 0;
|
||||
while (running_.load(std::memory_order_relaxed)) {
|
||||
auto buffer = webrtc::I420Buffer::Create(kWidth, kHeight);
|
||||
|
||||
// Fill Y plane with dark background
|
||||
memset(buffer->MutableDataY(), kBgY, buffer->StrideY() * kHeight);
|
||||
|
||||
// Fill U plane with tint
|
||||
int uvHeight = (kHeight + 1) / 2;
|
||||
memset(buffer->MutableDataU(), uTint_, buffer->StrideU() * uvHeight);
|
||||
|
||||
// Fill V plane with tint
|
||||
memset(buffer->MutableDataV(), vTint_, buffer->StrideV() * uvHeight);
|
||||
|
||||
// Render frame counter digits
|
||||
RenderDigits(buffer->MutableDataY(), buffer->StrideY(), frameNumber);
|
||||
|
||||
auto frame = webrtc::VideoFrame::Builder()
|
||||
.set_video_frame_buffer(buffer)
|
||||
.set_rotation(webrtc::kVideoRotation_0)
|
||||
.set_timestamp_us(rtc::TimeMicros())
|
||||
.build();
|
||||
|
||||
OnFrame(frame);
|
||||
|
||||
++frameNumber;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000 / kFps));
|
||||
}
|
||||
}
|
||||
|
||||
void FakeVideoTrackSource::RenderDigits(uint8_t* yPlane, int strideY, int frameNumber) {
|
||||
// Convert frame number to decimal digits
|
||||
char numStr[16];
|
||||
snprintf(numStr, sizeof(numStr), "%d", frameNumber);
|
||||
int numDigits = static_cast<int>(strlen(numStr));
|
||||
|
||||
int xOffset = kMargin;
|
||||
for (int d = 0; d < numDigits; ++d) {
|
||||
int digit = numStr[d] - '0';
|
||||
const uint8_t* bitmap = kDigitBitmaps[digit];
|
||||
|
||||
for (int row = 0; row < kDigitH; ++row) {
|
||||
uint8_t rowBits = bitmap[row];
|
||||
for (int col = 0; col < kDigitW; ++col) {
|
||||
if (rowBits & (1 << (kDigitW - 1 - col))) {
|
||||
// Fill scaled pixel block
|
||||
int px = xOffset + col * kScale;
|
||||
int py = kMargin + row * kScale;
|
||||
for (int sy = 0; sy < kScale; ++sy) {
|
||||
for (int sx = 0; sx < kScale; ++sx) {
|
||||
int x = px + sx;
|
||||
int y = py + sy;
|
||||
if (x < kWidth && y < kHeight) {
|
||||
yPlane[y * strideY + x] = kDigitY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
xOffset += kDigitW * kScale + kDigitSpacing;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include "api/video/i420_buffer.h"
|
||||
#include "media/base/adapted_video_track_source.h"
|
||||
#include "rtc_base/ref_counted_object.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
// Generates 640x360 I420 frames at 30fps with per-participant color tint
|
||||
// and an incrementing frame counter rendered as block digits.
|
||||
class FakeVideoTrackSource : public rtc::AdaptedVideoTrackSource {
|
||||
public:
|
||||
static rtc::scoped_refptr<FakeVideoTrackSource> Create(int participantId);
|
||||
|
||||
~FakeVideoTrackSource() override;
|
||||
|
||||
void Stop();
|
||||
|
||||
// VideoTrackSourceInterface
|
||||
SourceState state() const override { return kLive; }
|
||||
bool remote() const override { return false; }
|
||||
bool is_screencast() const override { return false; }
|
||||
absl::optional<bool> needs_denoising() const override { return false; }
|
||||
|
||||
protected:
|
||||
explicit FakeVideoTrackSource(int participantId);
|
||||
|
||||
private:
|
||||
void GenerateThread();
|
||||
void RenderDigits(uint8_t* yPlane, int strideY, int frameNumber);
|
||||
|
||||
int participantId_;
|
||||
uint8_t uTint_;
|
||||
uint8_t vTint_;
|
||||
std::atomic<bool> running_{true};
|
||||
std::thread thread_;
|
||||
};
|
||||
@@ -0,0 +1,205 @@
|
||||
#include "group_churn_mode.h"
|
||||
#include "group_participant.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
|
||||
// CGo header
|
||||
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
|
||||
|
||||
int runGroupChurnMode(
|
||||
int customParticipants,
|
||||
int referenceParticipants,
|
||||
int duration,
|
||||
bool quiet,
|
||||
bool video,
|
||||
int churnCycles
|
||||
) {
|
||||
gGroupQuiet = quiet;
|
||||
gGroupStartTime = std::chrono::steady_clock::now();
|
||||
|
||||
int baseCount = customParticipants + referenceParticipants;
|
||||
if (baseCount < 2) {
|
||||
fprintf(stderr, "Error: need at least 2 base participants total\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
groupLog("Churn", "initializing Go SFU...");
|
||||
|
||||
int rc = GoSfu_Init();
|
||||
if (rc != 0) {
|
||||
fprintf(stderr, "Error: GoSfu_Init failed with %d\n", rc);
|
||||
return 1;
|
||||
}
|
||||
|
||||
GoInt sfuHandle = GoSfu_Create();
|
||||
if (sfuHandle <= 0) {
|
||||
fprintf(stderr, "Error: GoSfu_Create failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
groupLog("Churn", "SFU handle=%lld, base=%d (custom=%d, ref=%d), cycles=%d, video=%s",
|
||||
(long long)sfuHandle, baseCount, customParticipants, referenceParticipants,
|
||||
churnCycles, video ? "yes" : "no");
|
||||
|
||||
auto threads = tgcalls::StaticThreads::getThreads();
|
||||
|
||||
// --- Phase 1: Create base group ---
|
||||
groupLog("Churn", "creating base group...");
|
||||
std::vector<std::unique_ptr<ParticipantState>> baseStates;
|
||||
bool anyFailed = false;
|
||||
|
||||
for (int i = 0; i < baseCount; ++i) {
|
||||
bool isReference = (i >= customParticipants);
|
||||
auto state = createParticipant(i, isReference, sfuHandle, threads, quiet, video, &baseStates);
|
||||
if (!state) {
|
||||
anyFailed = true;
|
||||
continue;
|
||||
}
|
||||
baseStates.push_back(std::move(state));
|
||||
}
|
||||
|
||||
// Wait for all base participants to connect
|
||||
groupLog("Churn", "waiting for base group connections...");
|
||||
auto waitStart = std::chrono::steady_clock::now();
|
||||
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(15)) {
|
||||
int connectedCount = 0;
|
||||
for (const auto& s : baseStates) {
|
||||
if (s->wasConnected.load()) connectedCount++;
|
||||
}
|
||||
if (connectedCount == (int)baseStates.size()) {
|
||||
groupLog("Churn", "all %d base participants connected", (int)baseStates.size());
|
||||
break;
|
||||
}
|
||||
groupLog("Churn", "base connected: %d/%d", connectedCount, (int)baseStates.size());
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
// Wait for audio to flow in base group
|
||||
groupLog("Churn", "waiting for base group audio...");
|
||||
waitStart = std::chrono::steady_clock::now();
|
||||
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(10)) {
|
||||
int audioCount = 0;
|
||||
for (const auto& s : baseStates) {
|
||||
if (s->receivedAudio.load()) audioCount++;
|
||||
}
|
||||
if (audioCount == (int)baseStates.size()) {
|
||||
groupLog("Churn", "all %d base participants receiving audio", (int)baseStates.size());
|
||||
break;
|
||||
}
|
||||
groupLog("Churn", "base audio: %d/%d", audioCount, (int)baseStates.size());
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
// --- Phase 2: Churn loop ---
|
||||
groupLog("Churn", "starting churn: %d cycles", churnCycles);
|
||||
int nextId = baseCount;
|
||||
int completedCycles = 0;
|
||||
|
||||
for (int cycle = 0; cycle < churnCycles; ++cycle) {
|
||||
bool isReference = (cycle % 2 == 1);
|
||||
int churnId = nextId++;
|
||||
|
||||
auto churner = createParticipant(churnId, isReference, sfuHandle, threads, quiet, video, &baseStates);
|
||||
if (!churner) {
|
||||
groupLog("Churn", "cycle %d: createParticipant failed for id=%d", cycle, churnId);
|
||||
anyFailed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Wait briefly for connection (up to 3s)
|
||||
auto connStart = std::chrono::steady_clock::now();
|
||||
while (std::chrono::steady_clock::now() - connStart < std::chrono::seconds(3)) {
|
||||
if (churner->wasConnected.load()) break;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
if (!churner->wasConnected.load()) {
|
||||
groupLog("Churn", "cycle %d: churner %d did not connect (continuing anyway)", cycle, churnId);
|
||||
}
|
||||
|
||||
// Leave
|
||||
stopParticipant(churner.get(), sfuHandle);
|
||||
completedCycles++;
|
||||
|
||||
if ((cycle + 1) % 10 == 0) {
|
||||
groupLog("Churn", "progress: %d/%d cycles completed", cycle + 1, churnCycles);
|
||||
}
|
||||
}
|
||||
|
||||
groupLog("Churn", "churn complete: %d/%d cycles succeeded", completedCycles, churnCycles);
|
||||
|
||||
// --- Phase 3: Stabilize and validate ---
|
||||
groupLog("Churn", "stabilizing for %d seconds...", duration);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(duration));
|
||||
|
||||
auto result = validateGroupState(baseStates, video);
|
||||
|
||||
// --- Phase 4: Teardown ---
|
||||
groupLog("Churn", "stopping base participants...");
|
||||
|
||||
// Stop video sources
|
||||
for (auto& s : baseStates) {
|
||||
if (s->videoSource) {
|
||||
s->videoSource->Stop();
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instances
|
||||
std::atomic<int> stopCount{0};
|
||||
std::mutex stopMutex;
|
||||
std::condition_variable stopCv;
|
||||
|
||||
for (const auto& s : baseStates) {
|
||||
if (s->instance) {
|
||||
int pid_local = s->id;
|
||||
s->instance->stop([&stopCount, &stopMutex, &stopCv, pid_local]() {
|
||||
groupLog("Churn", "base participant %d stopped", pid_local);
|
||||
stopCount.fetch_add(1);
|
||||
std::lock_guard<std::mutex> lock(stopMutex);
|
||||
stopCv.notify_all();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(stopMutex);
|
||||
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
|
||||
return stopCount.load() >= (int)baseStates.size();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& s : baseStates) {
|
||||
s->instance.reset();
|
||||
}
|
||||
|
||||
GoSfu_Destroy(sfuHandle);
|
||||
GoSfu_Shutdown();
|
||||
|
||||
// Print summary
|
||||
bool success = result.success && !anyFailed && (completedCycles == churnCycles);
|
||||
|
||||
printf("\n=== Group Churn Test Summary ===\n");
|
||||
printf("Base participants: %d (custom=%d, reference=%d)\n",
|
||||
baseCount, customParticipants, referenceParticipants);
|
||||
printf("Churn cycles: %d/%d completed\n", completedCycles, churnCycles);
|
||||
printf("Video: %s\n", video ? "yes" : "no");
|
||||
printf("Stabilization: %ds\n", duration);
|
||||
printf("Base connected: %d/%d\n", result.connectedCount, result.totalParticipants);
|
||||
printf("Base audio received: %d/%d\n", result.audioReceivedCount, result.totalParticipants);
|
||||
if (video) {
|
||||
printf("Base video received: %d/%d\n", result.videoReceivedPairs, result.videoExpectedPairs);
|
||||
}
|
||||
printf("Result: %s\n", success ? "SUCCESS" : "FAILED");
|
||||
|
||||
// Clean up log files
|
||||
for (const auto& s : baseStates) {
|
||||
unlink(s->logPath.c_str());
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
fflush(stderr);
|
||||
_exit(success ? 0 : 1);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
int runGroupChurnMode(
|
||||
int customParticipants,
|
||||
int referenceParticipants,
|
||||
int duration,
|
||||
bool quiet,
|
||||
bool video,
|
||||
int churnCycles
|
||||
);
|
||||
@@ -0,0 +1,173 @@
|
||||
#include "group_mode.h"
|
||||
#include "group_participant.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
|
||||
// CGo header
|
||||
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
|
||||
|
||||
int runGroupMode(int customParticipants, int referenceParticipants, int duration, bool quiet, bool video, const std::string& networkScenario) {
|
||||
gGroupQuiet = quiet;
|
||||
gGroupStartTime = std::chrono::steady_clock::now();
|
||||
|
||||
int participants = customParticipants + referenceParticipants;
|
||||
if (participants < 2) {
|
||||
fprintf(stderr, "Error: need at least 2 participants total\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
groupLog("Group", "initializing Go SFU...");
|
||||
|
||||
int rc = GoSfu_Init();
|
||||
if (rc != 0) {
|
||||
fprintf(stderr, "Error: GoSfu_Init failed with %d\n", rc);
|
||||
return 1;
|
||||
}
|
||||
|
||||
GoInt sfuHandle = GoSfu_Create();
|
||||
if (sfuHandle <= 0) {
|
||||
fprintf(stderr, "Error: GoSfu_Create failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
groupLog("Group", "created SFU handle=%lld, custom=%d, reference=%d, duration=%ds",
|
||||
(long long)sfuHandle, customParticipants, referenceParticipants, duration);
|
||||
|
||||
auto threads = tgcalls::StaticThreads::getThreads();
|
||||
|
||||
// Create participants
|
||||
std::vector<std::unique_ptr<ParticipantState>> states;
|
||||
bool anyFailed = false;
|
||||
|
||||
for (int i = 0; i < participants; ++i) {
|
||||
bool isReference = (i >= customParticipants);
|
||||
auto state = createParticipant(i, isReference, sfuHandle, threads, quiet, video, &states);
|
||||
if (!state) {
|
||||
anyFailed = true;
|
||||
continue;
|
||||
}
|
||||
states.push_back(std::move(state));
|
||||
}
|
||||
|
||||
// Wait for all participants to connect
|
||||
groupLog("Group", "waiting for connections...");
|
||||
bool allConnected = false;
|
||||
auto waitStart = std::chrono::steady_clock::now();
|
||||
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(15)) {
|
||||
int connectedCount = 0;
|
||||
for (const auto& s : states) {
|
||||
if (s->wasConnected.load()) connectedCount++;
|
||||
}
|
||||
if (connectedCount == (int)states.size()) {
|
||||
allConnected = true;
|
||||
groupLog("Group", "all %d participants connected", (int)states.size());
|
||||
break;
|
||||
}
|
||||
groupLog("Group", "connected: %d/%d", connectedCount, (int)states.size());
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
|
||||
if (!allConnected) {
|
||||
int connectedCount = 0;
|
||||
for (const auto& s : states) {
|
||||
if (s->wasConnected.load()) connectedCount++;
|
||||
}
|
||||
groupLog("Group", "connection timeout: %d/%d connected", connectedCount, (int)states.size());
|
||||
}
|
||||
|
||||
// Run for the specified duration, optionally with network scenario.
|
||||
if (!networkScenario.empty() && networkScenario == "step-down-up") {
|
||||
// Scenario: start uncapped, then step down, step up, uncap.
|
||||
// Split duration into 4 phases.
|
||||
int phase = std::max(duration / 4, 2);
|
||||
groupLog("Group", "network-scenario '%s': phase duration=%ds", networkScenario.c_str(), phase);
|
||||
|
||||
// Phase 1: uncapped (should be layer 2 on high BW).
|
||||
groupLog("Group", "phase 1: uncapped");
|
||||
std::this_thread::sleep_for(std::chrono::seconds(phase));
|
||||
|
||||
// Phase 2: cap to 80 kbps (should force downswitch to layer 0).
|
||||
groupLog("Group", "phase 2: cap 80kbps");
|
||||
for (const auto& s : states) {
|
||||
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 80000);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(phase));
|
||||
|
||||
// Phase 3: cap to 200 kbps (should allow upswitch to layer 1).
|
||||
groupLog("Group", "phase 3: cap 200kbps");
|
||||
for (const auto& s : states) {
|
||||
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 200000);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(phase));
|
||||
|
||||
// Phase 4: uncap (should allow upswitch to layer 2).
|
||||
groupLog("Group", "phase 4: uncapped");
|
||||
for (const auto& s : states) {
|
||||
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 0);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(phase));
|
||||
} else {
|
||||
groupLog("Group", "running for %d seconds...", duration);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(duration));
|
||||
}
|
||||
|
||||
// Stop all participants (using GoSfu_Destroy for bulk teardown)
|
||||
groupLog("Group", "stopping participants...");
|
||||
|
||||
// Stop video sources first
|
||||
for (auto& s : states) {
|
||||
if (s->videoSource) {
|
||||
s->videoSource->Stop();
|
||||
}
|
||||
}
|
||||
|
||||
// Stop instances
|
||||
std::atomic<int> stopCount{0};
|
||||
std::mutex stopMutex;
|
||||
std::condition_variable stopCv;
|
||||
|
||||
for (const auto& s : states) {
|
||||
if (s->instance) {
|
||||
int pid_local = s->id;
|
||||
s->instance->stop([&stopCount, &stopMutex, &stopCv, pid_local]() {
|
||||
groupLog("Group", "participant %d stopped", pid_local);
|
||||
stopCount.fetch_add(1);
|
||||
std::lock_guard<std::mutex> lock(stopMutex);
|
||||
stopCv.notify_all();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(stopMutex);
|
||||
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
|
||||
return stopCount.load() >= (int)states.size();
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& s : states) {
|
||||
s->instance.reset();
|
||||
}
|
||||
|
||||
// Destroy SFU
|
||||
GoSfu_Destroy(sfuHandle);
|
||||
GoSfu_Shutdown();
|
||||
|
||||
// Validate and print summary
|
||||
auto result = validateGroupState(states, video);
|
||||
bool success = printGroupSummary(customParticipants, referenceParticipants, duration, video, result, anyFailed);
|
||||
|
||||
// Clean up log files
|
||||
for (const auto& s : states) {
|
||||
unlink(s->logPath.c_str());
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
fflush(stderr);
|
||||
_exit(success ? 0 : 1);
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
int runGroupMode(int customParticipants, int referenceParticipants, int duration, bool quiet, bool video, const std::string& networkScenario = "");
|
||||
@@ -0,0 +1,442 @@
|
||||
#include "group_participant.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <condition_variable>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "third-party/json11.hpp"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Globals
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
std::chrono::steady_clock::time_point gGroupStartTime = std::chrono::steady_clock::now();
|
||||
std::atomic<bool> gGroupQuiet{false};
|
||||
|
||||
double groupElapsed() {
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration<double>(now - gGroupStartTime).count();
|
||||
}
|
||||
|
||||
void groupLog(const char* tag, const char* fmt, ...) {
|
||||
if (gGroupQuiet) return;
|
||||
char buf[512];
|
||||
va_list ap;
|
||||
va_start(ap, fmt);
|
||||
vsnprintf(buf, sizeof(buf), fmt, ap);
|
||||
va_end(ap);
|
||||
fprintf(stderr, "[%7.3f] %s: %s\n", groupElapsed(), tag, buf);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupSineRecorder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
GroupSineRecorder::GroupSineRecorder() {
|
||||
buffer_.resize(kFrameSamples * kChannels);
|
||||
}
|
||||
|
||||
tgcalls::AudioFrame GroupSineRecorder::Record() {
|
||||
for (size_t i = 0; i < kFrameSamples; ++i) {
|
||||
double t = static_cast<double>(phase_) / kSampleRate;
|
||||
int16_t sample = static_cast<int16_t>(kAmplitude * std::sin(2.0 * M_PI * kFrequency * t));
|
||||
for (size_t ch = 0; ch < kChannels; ++ch) {
|
||||
buffer_[i * kChannels + ch] = sample;
|
||||
}
|
||||
++phase_;
|
||||
}
|
||||
|
||||
tgcalls::AudioFrame frame;
|
||||
frame.audio_samples = buffer_.data();
|
||||
frame.num_samples = kFrameSamples;
|
||||
frame.bytes_per_sample = sizeof(int16_t);
|
||||
frame.num_channels = kChannels;
|
||||
frame.samples_per_sec = kSampleRate;
|
||||
frame.elapsed_time_ms = 0;
|
||||
frame.ntp_time_ms = 0;
|
||||
return frame;
|
||||
}
|
||||
|
||||
int32_t GroupSineRecorder::WaitForUs() {
|
||||
return 10000; // 10ms
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupNoOpRenderer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
bool GroupNoOpRenderer::Render(const tgcalls::AudioFrame&) { return true; }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SimpleRequestMediaChannelDescriptionTask
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
void SimpleRequestMediaChannelDescriptionTask::cancel() {}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// createParticipant
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
std::unique_ptr<ParticipantState> createParticipant(
|
||||
int id,
|
||||
bool isReference,
|
||||
GoInt sfuHandle,
|
||||
std::shared_ptr<tgcalls::Threads> threads,
|
||||
bool quiet,
|
||||
bool video,
|
||||
std::vector<std::unique_ptr<ParticipantState>>* allStates
|
||||
) {
|
||||
auto state = std::make_unique<ParticipantState>();
|
||||
state->id = id;
|
||||
state->isReference = isReference;
|
||||
state->logPath = "/tmp/tgcalls_group_p" + std::to_string(id) + "_" + std::to_string(getpid()) + ".log";
|
||||
|
||||
std::string tag = "P" + std::to_string(id);
|
||||
|
||||
auto recorder = std::make_shared<GroupSineRecorder>();
|
||||
auto renderer = std::make_shared<GroupNoOpRenderer>();
|
||||
|
||||
ParticipantState* statePtr = state.get();
|
||||
GoInt sfuH = sfuHandle;
|
||||
|
||||
tgcalls::GroupInstanceDescriptor descriptor;
|
||||
descriptor.threads = threads;
|
||||
descriptor.config.need_log = true;
|
||||
descriptor.config.logPath = {state->logPath};
|
||||
descriptor.networkStateUpdated = [statePtr, tag](tgcalls::GroupNetworkState networkState) {
|
||||
groupLog(tag.c_str(), "network state: connected=%s", networkState.isConnected ? "true" : "false");
|
||||
statePtr->connected.store(networkState.isConnected);
|
||||
if (networkState.isConnected) {
|
||||
statePtr->wasConnected.store(true);
|
||||
}
|
||||
};
|
||||
descriptor.audioLevelsUpdated = [statePtr, tag](tgcalls::GroupLevelsUpdate const &update) {
|
||||
for (const auto& level : update.updates) {
|
||||
if (level.value.level > 0.01f) {
|
||||
groupLog(tag.c_str(), "audio level: ssrc=%u level=%.3f voice=%d",
|
||||
level.ssrc, level.value.level, level.value.voice);
|
||||
}
|
||||
if (level.ssrc != 0 && level.value.level > 0.05f) {
|
||||
statePtr->receivedAudio.store(true);
|
||||
}
|
||||
}
|
||||
};
|
||||
descriptor.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
|
||||
renderer, recorder,
|
||||
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
|
||||
);
|
||||
|
||||
descriptor.requestMediaChannelDescriptions = [sfuH, tag, allStates](
|
||||
std::vector<uint32_t> const &ssrcs,
|
||||
std::function<void(std::vector<tgcalls::MediaChannelDescription> &&)> callback
|
||||
) -> std::shared_ptr<tgcalls::RequestMediaChannelDescriptionTask> {
|
||||
std::set<uint32_t> audioSsrcs;
|
||||
for (const auto& s : *allStates) {
|
||||
if (s->audioSsrc != 0) audioSsrcs.insert(s->audioSsrc);
|
||||
}
|
||||
std::vector<tgcalls::MediaChannelDescription> descriptions;
|
||||
for (uint32_t ssrc : ssrcs) {
|
||||
GoInt ownerID = GoSfu_QuerySsrc(sfuH, (GoUint)ssrc);
|
||||
bool isAudio = audioSsrcs.count(ssrc) > 0;
|
||||
groupLog(tag.c_str(), "requestMediaChannelDescriptions: ssrc=%u -> owner=%lld type=%s",
|
||||
ssrc, (long long)ownerID, isAudio ? "audio" : "video");
|
||||
tgcalls::MediaChannelDescription desc;
|
||||
desc.type = isAudio ? tgcalls::MediaChannelDescription::Type::Audio
|
||||
: tgcalls::MediaChannelDescription::Type::Video;
|
||||
desc.audioSsrc = ssrc;
|
||||
desc.userId = ownerID;
|
||||
descriptions.push_back(std::move(desc));
|
||||
}
|
||||
callback(std::move(descriptions));
|
||||
return std::make_shared<SimpleRequestMediaChannelDescriptionTask>();
|
||||
};
|
||||
|
||||
descriptor.outgoingAudioBitrateKbit = 32;
|
||||
descriptor.disableIncomingChannels = false;
|
||||
descriptor.useDummyChannel = true;
|
||||
|
||||
// Video configuration
|
||||
if (video) {
|
||||
auto videoSource = FakeVideoTrackSource::Create(id);
|
||||
state->videoSource = videoSource;
|
||||
state->endpointId = std::to_string(id);
|
||||
descriptor.videoContentType = tgcalls::VideoContentType::Generic;
|
||||
descriptor.videoCodecPreferences = {tgcalls::VideoCodecName::H264};
|
||||
// Set the outgoing video min bitrate to 600 kbps so the sender's
|
||||
// BWE floor is high enough to activate all 3 simulcast layers
|
||||
// (audio 32k + L0 min 50k + L1 min 100k + L2 min 300k = 482k).
|
||||
// On localhost, delay-based BWE over the loopback pacer has been
|
||||
// observed to drift down to ~80 kbps, keeping L2 disabled. Clamping
|
||||
// the min forces the encoder to keep L2 producing.
|
||||
descriptor.minOutgoingVideoBitrateKbit = 600;
|
||||
descriptor.getVideoSource = [videoSource]() -> webrtc::scoped_refptr<webrtc::VideoTrackSourceInterface> {
|
||||
return videoSource;
|
||||
};
|
||||
|
||||
descriptor.dataChannelMessageReceived = [statePtr, sfuH, tag](std::string const &message) {
|
||||
std::string parseErr;
|
||||
auto json = json11::Json::parse(message, parseErr);
|
||||
if (!parseErr.empty() || !json.is_object()) return;
|
||||
auto cls = json["colibriClass"].string_value();
|
||||
if (cls != "ActiveVideoSsrcs") return;
|
||||
|
||||
auto ssrcsArray = json["ssrcs"].array_items();
|
||||
if (ssrcsArray.empty()) return;
|
||||
|
||||
std::vector<tgcalls::VideoChannelDescription> videoChannels;
|
||||
for (const auto& entry : ssrcsArray) {
|
||||
std::string endpointId = entry["endpointId"].string_value();
|
||||
if (endpointId == statePtr->endpointId) continue;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(statePtr->videoSinksMutex);
|
||||
if (statePtr->videoSinks.count(endpointId) > 0) continue;
|
||||
}
|
||||
|
||||
int remoteId = 0;
|
||||
if (sscanf(endpointId.c_str(), "%d", &remoteId) != 1) continue;
|
||||
|
||||
char* ssrcsRaw = GoSfu_QueryVideoSsrcs(sfuH, (GoInt)remoteId);
|
||||
if (!ssrcsRaw) continue;
|
||||
std::string ssrcsJson(ssrcsRaw);
|
||||
GoSfu_Free(ssrcsRaw);
|
||||
|
||||
std::string err2;
|
||||
auto layers = json11::Json::parse(ssrcsJson, err2);
|
||||
if (!err2.empty() || !layers.is_array() || layers.array_items().empty()) continue;
|
||||
|
||||
tgcalls::VideoChannelDescription desc;
|
||||
desc.audioSsrc = 0;
|
||||
desc.userId = remoteId;
|
||||
desc.endpointId = endpointId;
|
||||
desc.maxQuality = tgcalls::VideoChannelDescription::Quality::Full;
|
||||
desc.minQuality = tgcalls::VideoChannelDescription::Quality::Full;
|
||||
|
||||
tgcalls::MediaSsrcGroup simGroup;
|
||||
simGroup.semantics = "SIM";
|
||||
for (const auto& layer : layers.array_items()) {
|
||||
uint32_t ssrc = static_cast<uint32_t>(static_cast<int64_t>(layer["ssrc"].number_value()));
|
||||
uint32_t fidSsrc = static_cast<uint32_t>(static_cast<int64_t>(layer["fidSsrc"].number_value()));
|
||||
if (ssrc == 0) continue;
|
||||
simGroup.ssrcs.push_back(ssrc);
|
||||
if (fidSsrc != 0) {
|
||||
tgcalls::MediaSsrcGroup fidGroup;
|
||||
fidGroup.semantics = "FID";
|
||||
fidGroup.ssrcs = {ssrc, fidSsrc};
|
||||
desc.ssrcGroups.push_back(std::move(fidGroup));
|
||||
}
|
||||
}
|
||||
desc.ssrcGroups.insert(desc.ssrcGroups.begin(), std::move(simGroup));
|
||||
videoChannels.push_back(std::move(desc));
|
||||
|
||||
auto sink = std::make_shared<FakeVideoSink>();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(statePtr->videoSinksMutex);
|
||||
statePtr->videoSinks[endpointId] = sink;
|
||||
}
|
||||
statePtr->instance->addIncomingVideoOutput(
|
||||
endpointId,
|
||||
std::weak_ptr<rtc::VideoSinkInterface<webrtc::VideoFrame>>(sink));
|
||||
|
||||
groupLog(tag.c_str(), "ActiveVideoSsrcs: adding video channel for endpoint %s", endpointId.c_str());
|
||||
}
|
||||
|
||||
if (!videoChannels.empty()) {
|
||||
statePtr->instance->setRequestedVideoChannels(std::move(videoChannels));
|
||||
}
|
||||
};
|
||||
} else {
|
||||
descriptor.videoContentType = tgcalls::VideoContentType::None;
|
||||
}
|
||||
|
||||
// Create instance
|
||||
if (isReference) {
|
||||
state->instance = std::make_unique<tgcalls::GroupInstanceReferenceImpl>(std::move(descriptor));
|
||||
groupLog(tag.c_str(), "created GroupInstanceReferenceImpl");
|
||||
} else {
|
||||
state->instance = std::make_unique<tgcalls::GroupInstanceCustomImpl>(std::move(descriptor));
|
||||
groupLog(tag.c_str(), "created GroupInstanceCustomImpl");
|
||||
}
|
||||
|
||||
// Set connection mode
|
||||
state->instance->setConnectionMode(
|
||||
tgcalls::GroupConnectionMode::GroupConnectionModeRtc, false, false);
|
||||
|
||||
// Emit join payload
|
||||
std::mutex joinMutex;
|
||||
std::condition_variable joinCv;
|
||||
bool joinReady = false;
|
||||
std::string joinJson;
|
||||
uint32_t joinSsrc = 0;
|
||||
|
||||
state->instance->emitJoinPayload([&](tgcalls::GroupJoinPayload const &payload) {
|
||||
std::lock_guard<std::mutex> lock(joinMutex);
|
||||
joinJson = payload.json;
|
||||
joinSsrc = payload.audioSsrc;
|
||||
joinReady = true;
|
||||
joinCv.notify_one();
|
||||
});
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(joinMutex);
|
||||
if (!joinCv.wait_for(lock, std::chrono::seconds(5), [&] { return joinReady; })) {
|
||||
fprintf(stderr, "Error: emitJoinPayload timed out for participant %d\n", id);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
state->audioSsrc = joinSsrc;
|
||||
groupLog(tag.c_str(), "join payload ready: ssrc=%u, json=%zu bytes", joinSsrc, joinJson.size());
|
||||
|
||||
// Join SFU
|
||||
GoInt iceControlling = isReference ? 0 : 1;
|
||||
char* responseRaw = GoSfu_Join(sfuHandle, (GoInt)id, const_cast<char*>(joinJson.c_str()), iceControlling);
|
||||
if (!responseRaw) {
|
||||
fprintf(stderr, "Error: GoSfu_Join returned null for participant %d\n", id);
|
||||
return nullptr;
|
||||
}
|
||||
std::string response(responseRaw);
|
||||
GoSfu_Free(responseRaw);
|
||||
|
||||
if (response.find("\"error\"") != std::string::npos) {
|
||||
fprintf(stderr, "Error: GoSfu_Join failed for participant %d: %s\n", id, response.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
groupLog(tag.c_str(), "SFU join response: %zu bytes", response.size());
|
||||
|
||||
state->instance->setJoinResponsePayload(response);
|
||||
state->instance->setIsMuted(false);
|
||||
|
||||
groupLog(tag.c_str(), "joined and unmuted");
|
||||
return state;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// stopParticipant
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
void stopParticipant(ParticipantState* state, GoInt sfuHandle) {
|
||||
if (!state || !state->instance) return;
|
||||
|
||||
std::string tag = "P" + std::to_string(state->id);
|
||||
|
||||
// Remove from SFU first so broadcasts go out to remaining participants.
|
||||
GoInt rc = GoSfu_Leave(sfuHandle, (GoInt)state->id);
|
||||
if (rc != 0) {
|
||||
groupLog(tag.c_str(), "GoSfu_Leave returned %lld (may already be removed)", (long long)rc);
|
||||
}
|
||||
|
||||
// Stop video source.
|
||||
if (state->videoSource) {
|
||||
state->videoSource->Stop();
|
||||
}
|
||||
|
||||
// Stop instance with timeout. Heap-allocate sync state so the stop callback
|
||||
// is safe even if it fires after the 5s timeout (avoids stack-frame UB).
|
||||
struct StopState {
|
||||
std::mutex mu;
|
||||
std::condition_variable cv;
|
||||
std::atomic<bool> done{false};
|
||||
};
|
||||
auto stopState = std::make_shared<StopState>();
|
||||
|
||||
state->instance->stop([stopState]() {
|
||||
stopState->done.store(true);
|
||||
std::lock_guard<std::mutex> lock(stopState->mu);
|
||||
stopState->cv.notify_all();
|
||||
});
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(stopState->mu);
|
||||
stopState->cv.wait_for(lock, std::chrono::seconds(5), [&] { return stopState->done.load(); });
|
||||
}
|
||||
|
||||
state->instance.reset();
|
||||
|
||||
// Clean up log file.
|
||||
unlink(state->logPath.c_str());
|
||||
|
||||
groupLog(tag.c_str(), "stopped and cleaned up");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// validateGroupState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
GroupValidationResult validateGroupState(
|
||||
const std::vector<std::unique_ptr<ParticipantState>>& states,
|
||||
bool video
|
||||
) {
|
||||
GroupValidationResult result{};
|
||||
result.totalParticipants = static_cast<int>(states.size());
|
||||
|
||||
for (const auto& s : states) {
|
||||
if (s->wasConnected.load()) result.connectedCount++;
|
||||
if (s->receivedAudio.load()) result.audioReceivedCount++;
|
||||
}
|
||||
|
||||
if (video) {
|
||||
int videoParticipants = 0;
|
||||
for (const auto& s : states) {
|
||||
if (s->videoSource) videoParticipants++;
|
||||
}
|
||||
result.videoExpectedPairs = videoParticipants * (videoParticipants - 1);
|
||||
|
||||
for (const auto& s : states) {
|
||||
std::lock_guard<std::mutex> lock(s->videoSinksMutex);
|
||||
for (const auto& [endpointId, sink] : s->videoSinks) {
|
||||
int frames = sink->frameCount();
|
||||
if (frames > 0) {
|
||||
result.videoReceivedPairs++;
|
||||
}
|
||||
groupLog("Validate", "P%d <- endpoint %s: %d video frames (%dx%d)",
|
||||
s->id, endpointId.c_str(), frames,
|
||||
sink->lastWidth(), sink->lastHeight());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.success = (result.connectedCount == result.totalParticipants &&
|
||||
result.audioReceivedCount == result.totalParticipants);
|
||||
if (video && result.videoExpectedPairs > 0) {
|
||||
result.success = result.success && (result.videoReceivedPairs >= result.videoExpectedPairs);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// printGroupSummary
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
bool printGroupSummary(
|
||||
int customParticipants,
|
||||
int referenceParticipants,
|
||||
int duration,
|
||||
bool video,
|
||||
const GroupValidationResult& result,
|
||||
bool anyFailed
|
||||
) {
|
||||
bool success = result.success && !anyFailed;
|
||||
|
||||
printf("\n=== Group Call Summary ===\n");
|
||||
printf("Custom participants: %d\n", customParticipants);
|
||||
printf("Reference participants: %d\n", referenceParticipants);
|
||||
printf("Total participants: %d\n", result.totalParticipants);
|
||||
printf("Duration: %ds\n", duration);
|
||||
printf("SFU: Go/Pion (in-process)\n");
|
||||
printf("Connected: %d/%d\n", result.connectedCount, result.totalParticipants);
|
||||
printf("Audio received: %d/%d\n", result.audioReceivedCount, result.totalParticipants);
|
||||
if (video) {
|
||||
printf("Video received: %d/%d\n", result.videoReceivedPairs, result.videoExpectedPairs);
|
||||
}
|
||||
printf("Result: %s\n", success ? "SUCCESS" : "FAILED");
|
||||
|
||||
return success;
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "group/GroupInstanceCustomImpl.h"
|
||||
#include "group/GroupInstanceImpl.h"
|
||||
#include "group/GroupInstanceReferenceImpl.h"
|
||||
#include "FakeAudioDeviceModule.h"
|
||||
#include "StaticThreads.h"
|
||||
#include "AudioFrame.h"
|
||||
#include "fake_video_source.h"
|
||||
#include "fake_video_sink.h"
|
||||
|
||||
// CGo header
|
||||
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Logging helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
extern std::chrono::steady_clock::time_point gGroupStartTime;
|
||||
extern std::atomic<bool> gGroupQuiet;
|
||||
|
||||
double groupElapsed();
|
||||
void groupLog(const char* tag, const char* fmt, ...);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupSineRecorder - generates 440 Hz sine tone
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class GroupSineRecorder : public tgcalls::FakeAudioDeviceModule::Recorder {
|
||||
public:
|
||||
GroupSineRecorder();
|
||||
tgcalls::AudioFrame Record() override;
|
||||
int32_t WaitForUs() override;
|
||||
|
||||
private:
|
||||
static constexpr size_t kSampleRate = 48000;
|
||||
static constexpr size_t kChannels = 2;
|
||||
static constexpr size_t kFrameSamples = 480;
|
||||
static constexpr double kFrequency = 440.0;
|
||||
static constexpr double kAmplitude = 3000.0;
|
||||
|
||||
std::vector<int16_t> buffer_;
|
||||
uint64_t phase_ = 0;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupNoOpRenderer - discards received audio
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class GroupNoOpRenderer : public tgcalls::FakeAudioDeviceModule::Renderer {
|
||||
public:
|
||||
bool Render(const tgcalls::AudioFrame&) override;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SimpleRequestMediaChannelDescriptionTask
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class SimpleRequestMediaChannelDescriptionTask : public tgcalls::RequestMediaChannelDescriptionTask {
|
||||
public:
|
||||
void cancel() override;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ParticipantState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct ParticipantState {
|
||||
int id;
|
||||
bool isReference;
|
||||
std::unique_ptr<tgcalls::GroupInstanceInterface> instance;
|
||||
std::atomic<bool> connected{false};
|
||||
std::atomic<bool> wasConnected{false};
|
||||
std::atomic<bool> receivedAudio{false};
|
||||
uint32_t audioSsrc{0};
|
||||
std::string logPath;
|
||||
|
||||
// Video fields
|
||||
std::string endpointId;
|
||||
rtc::scoped_refptr<FakeVideoTrackSource> videoSource;
|
||||
std::mutex videoSinksMutex;
|
||||
std::map<std::string, std::shared_ptr<FakeVideoSink>> videoSinks;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GroupValidationResult
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct GroupValidationResult {
|
||||
int totalParticipants;
|
||||
int connectedCount;
|
||||
int audioReceivedCount;
|
||||
int videoReceivedPairs;
|
||||
int videoExpectedPairs;
|
||||
bool success;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Participant lifecycle functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Creates a fully initialized participant: builds descriptor, creates instance,
|
||||
// joins SFU, sets join response, unmutes. Returns nullptr on failure.
|
||||
std::unique_ptr<ParticipantState> createParticipant(
|
||||
int id,
|
||||
bool isReference,
|
||||
GoInt sfuHandle,
|
||||
std::shared_ptr<tgcalls::Threads> threads,
|
||||
bool quiet,
|
||||
bool video,
|
||||
std::vector<std::unique_ptr<ParticipantState>>* allStates
|
||||
);
|
||||
|
||||
// Clean teardown: GoSfu_Leave, stop video, stop instance, reset.
|
||||
void stopParticipant(ParticipantState* state, GoInt sfuHandle);
|
||||
|
||||
// Validates group state: connection, audio, video. Returns result struct.
|
||||
GroupValidationResult validateGroupState(
|
||||
const std::vector<std::unique_ptr<ParticipantState>>& states,
|
||||
bool video
|
||||
);
|
||||
|
||||
// Prints a group call summary to stdout. Returns the success boolean.
|
||||
bool printGroupSummary(
|
||||
int customParticipants,
|
||||
int referenceParticipants,
|
||||
int duration,
|
||||
bool video,
|
||||
const GroupValidationResult& result,
|
||||
bool anyFailed
|
||||
);
|
||||
@@ -0,0 +1,602 @@
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <condition_variable>
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <random>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
#include "group_mode.h"
|
||||
#include "group_churn_mode.h"
|
||||
#include "Instance.h"
|
||||
#include "FakeAudioDeviceModule.h"
|
||||
#include "VideoCaptureInterface.h"
|
||||
#include "v2/InstanceV2Impl.h"
|
||||
#include "v2/InstanceV2CompatImpl.h"
|
||||
#include "v2/InstanceV2ReferenceImpl.h"
|
||||
|
||||
#include "modules/audio_device/include/audio_device.h"
|
||||
#include "api/task_queue/task_queue_factory.h"
|
||||
|
||||
// Stub: AudioDeviceModule::Create is referenced by InstanceV2Impl as a fallback
|
||||
// but never called when createAudioDeviceModule is provided in the Descriptor.
|
||||
namespace webrtc {
|
||||
rtc::scoped_refptr<AudioDeviceModule> AudioDeviceModule::Create(
|
||||
AudioDeviceModule::AudioLayer audio_layer,
|
||||
TaskQueueFactory* task_queue_factory) {
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace webrtc
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static auto gStartTime = std::chrono::steady_clock::now();
|
||||
|
||||
static double elapsed() {
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
return std::chrono::duration<double>(now - gStartTime).count();
|
||||
}
|
||||
|
||||
static bool gQuiet = false;
|
||||
|
||||
static void logMsg(const char* role, const char* fmt, ...) {
|
||||
if (gQuiet) return;
|
||||
char buf[512];
|
||||
va_list ap;
|
||||
va_start(ap, fmt);
|
||||
vsnprintf(buf, sizeof(buf), fmt, ap);
|
||||
va_end(ap);
|
||||
fprintf(stderr, "[%7.3f] %s: %s\n", elapsed(), role, buf);
|
||||
}
|
||||
|
||||
static const char* stateName(tgcalls::State s) {
|
||||
switch (s) {
|
||||
case tgcalls::State::WaitInit: return "WaitInit";
|
||||
case tgcalls::State::WaitInitAck: return "WaitInitAck";
|
||||
case tgcalls::State::Established: return "Established";
|
||||
case tgcalls::State::Failed: return "Failed";
|
||||
case tgcalls::State::Reconnecting:return "Reconnecting";
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
static std::string hexEncode(const std::array<uint8_t, 16>& data) {
|
||||
char buf[33];
|
||||
for (size_t i = 0; i < 16; ++i) {
|
||||
snprintf(buf + i * 2, 3, "%02x", data[i]);
|
||||
}
|
||||
return std::string(buf, 32);
|
||||
}
|
||||
|
||||
static tgcalls::RtcServer makeReflectorServer(const std::string& host, uint16_t port,
|
||||
const std::array<uint8_t, 16>& peerTag) {
|
||||
tgcalls::RtcServer server;
|
||||
server.id = 1;
|
||||
server.host = host;
|
||||
server.port = port;
|
||||
server.login = "reflector";
|
||||
server.password = hexEncode(peerTag);
|
||||
server.isTurn = true;
|
||||
server.isTcp = false;
|
||||
return server;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SineRecorder - generates 440 Hz sine tone
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class SineRecorder : public tgcalls::FakeAudioDeviceModule::Recorder {
|
||||
public:
|
||||
SineRecorder() {
|
||||
buffer_.resize(kFrameSamples * kChannels);
|
||||
}
|
||||
|
||||
tgcalls::AudioFrame Record() override {
|
||||
for (size_t i = 0; i < kFrameSamples; ++i) {
|
||||
double t = static_cast<double>(phase_) / kSampleRate;
|
||||
int16_t sample = static_cast<int16_t>(kAmplitude * std::sin(2.0 * M_PI * kFrequency * t));
|
||||
for (size_t ch = 0; ch < kChannels; ++ch) {
|
||||
buffer_[i * kChannels + ch] = sample;
|
||||
}
|
||||
++phase_;
|
||||
}
|
||||
|
||||
tgcalls::AudioFrame frame;
|
||||
frame.audio_samples = buffer_.data();
|
||||
frame.num_samples = kFrameSamples;
|
||||
frame.bytes_per_sample = sizeof(int16_t);
|
||||
frame.num_channels = kChannels;
|
||||
frame.samples_per_sec = kSampleRate;
|
||||
frame.elapsed_time_ms = 0;
|
||||
frame.ntp_time_ms = 0;
|
||||
return frame;
|
||||
}
|
||||
|
||||
int32_t WaitForUs() override {
|
||||
return 10000; // 10ms
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr size_t kSampleRate = 48000;
|
||||
static constexpr size_t kChannels = 2;
|
||||
static constexpr size_t kFrameSamples = 480; // 10ms at 48kHz
|
||||
static constexpr double kFrequency = 440.0;
|
||||
static constexpr double kAmplitude = 3000.0;
|
||||
|
||||
std::vector<int16_t> buffer_;
|
||||
uint64_t phase_ = 0;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NoOpRenderer - discards received audio (validation is done via BWE stats)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
class NoOpRenderer : public tgcalls::FakeAudioDeviceModule::Renderer {
|
||||
public:
|
||||
bool Render(const tgcalls::AudioFrame&) override { return true; }
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SignalingBridge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct SignalingBridge {
|
||||
std::mutex mutex;
|
||||
std::shared_ptr<tgcalls::Instance> caller;
|
||||
std::shared_ptr<tgcalls::Instance> callee;
|
||||
|
||||
// Network simulation
|
||||
double dropRate = 0.0;
|
||||
int delayMinMs = 0;
|
||||
int delayMaxMs = 0;
|
||||
std::mt19937 rng{std::random_device{}()};
|
||||
|
||||
void deliver(const char* fromRole, const std::vector<uint8_t>& data,
|
||||
std::shared_ptr<tgcalls::Instance>& target) {
|
||||
if (dropRate > 0.0) {
|
||||
std::uniform_real_distribution<double> dropDist(0.0, 1.0);
|
||||
if (dropDist(rng) < dropRate) {
|
||||
logMsg(fromRole, "signaling DROPPED (%zu bytes)", data.size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (delayMaxMs > 0) {
|
||||
std::uniform_int_distribution<int> delayDist(delayMinMs, delayMaxMs);
|
||||
int delayMs = delayDist(rng);
|
||||
if (delayMs > 0) {
|
||||
logMsg(fromRole, "signaling delayed %dms (%zu bytes)", delayMs, data.size());
|
||||
auto dataCopy = data;
|
||||
auto targetWeak = std::weak_ptr<tgcalls::Instance>(target);
|
||||
std::thread([dataCopy, targetWeak, delayMs]() {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
|
||||
if (auto t = targetWeak.lock()) {
|
||||
t->receiveSignalingData(dataCopy);
|
||||
}
|
||||
}).detach();
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (target) {
|
||||
target->receiveSignalingData(data);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CallState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct CallState {
|
||||
std::mutex mutex;
|
||||
tgcalls::State callerState = tgcalls::State::WaitInit;
|
||||
tgcalls::State calleeState = tgcalls::State::WaitInit;
|
||||
double establishedAt = -1.0;
|
||||
std::vector<std::string> errors;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// main
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
int duration = 10;
|
||||
std::string mode;
|
||||
std::string reflectorAddr;
|
||||
std::string reflectorList;
|
||||
std::string version = "13.0.0";
|
||||
std::string version2;
|
||||
double dropRate = 0.0;
|
||||
int delayMinMs = 0;
|
||||
int delayMaxMs = 0;
|
||||
int participants = 3;
|
||||
int referenceParticipants = 0;
|
||||
bool enableVideo = false;
|
||||
int churnCycles = 100;
|
||||
std::string networkScenario;
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
if (std::string(argv[i]) == "--duration" && i + 1 < argc) {
|
||||
duration = std::atoi(argv[++i]);
|
||||
} else if (std::string(argv[i]) == "--quiet") {
|
||||
gQuiet = true;
|
||||
} else if (std::string(argv[i]) == "--mode" && i + 1 < argc) {
|
||||
mode = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--reflector" && i + 1 < argc) {
|
||||
reflectorAddr = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--reflector-list" && i + 1 < argc) {
|
||||
reflectorList = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--drop-rate" && i + 1 < argc) {
|
||||
dropRate = std::atof(argv[++i]);
|
||||
} else if (std::string(argv[i]) == "--version" && i + 1 < argc) {
|
||||
version = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--version2" && i + 1 < argc) {
|
||||
version2 = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--participants" && i + 1 < argc) {
|
||||
participants = std::atoi(argv[++i]);
|
||||
} else if (std::string(argv[i]) == "--reference-participants" && i + 1 < argc) {
|
||||
referenceParticipants = std::atoi(argv[++i]);
|
||||
} else if (std::string(argv[i]) == "--video") {
|
||||
enableVideo = true;
|
||||
} else if (std::string(argv[i]) == "--churn-cycles" && i + 1 < argc) {
|
||||
churnCycles = std::atoi(argv[++i]);
|
||||
} else if (std::string(argv[i]) == "--network-scenario" && i + 1 < argc) {
|
||||
networkScenario = argv[++i];
|
||||
} else if (std::string(argv[i]) == "--delay" && i + 1 < argc) {
|
||||
std::string delayStr = argv[++i];
|
||||
auto dashPos = delayStr.find('-');
|
||||
if (dashPos != std::string::npos) {
|
||||
delayMinMs = std::atoi(delayStr.substr(0, dashPos).c_str());
|
||||
delayMaxMs = std::atoi(delayStr.substr(dashPos + 1).c_str());
|
||||
} else {
|
||||
delayMinMs = 0;
|
||||
delayMaxMs = std::atoi(delayStr.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (version2.empty()) {
|
||||
version2 = version;
|
||||
}
|
||||
|
||||
// If --reflector-list provided, pick one at random
|
||||
if (!reflectorList.empty()) {
|
||||
std::vector<std::string> addrs;
|
||||
size_t pos = 0;
|
||||
while (pos < reflectorList.size()) {
|
||||
size_t next = reflectorList.find(',', pos);
|
||||
if (next == std::string::npos) next = reflectorList.size();
|
||||
std::string addr = reflectorList.substr(pos, next - pos);
|
||||
if (!addr.empty()) addrs.push_back(addr);
|
||||
pos = next + 1;
|
||||
}
|
||||
if (addrs.empty()) {
|
||||
fprintf(stderr, "Error: --reflector-list is empty\n");
|
||||
return 1;
|
||||
}
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
std::uniform_int_distribution<size_t> dist(0, addrs.size() - 1);
|
||||
reflectorAddr = addrs[dist(rng)];
|
||||
if (reflectorAddr.rfind(':') == std::string::npos) {
|
||||
std::uniform_int_distribution<int> portDist(596, 599);
|
||||
reflectorAddr += ":" + std::to_string(portDist(rng));
|
||||
}
|
||||
if (mode.empty()) mode = "reflector";
|
||||
}
|
||||
|
||||
// Validate --mode
|
||||
if (mode.empty()) {
|
||||
fprintf(stderr, "Error: --mode is required (p2p, reflector, group, or group-churn)\n");
|
||||
return 1;
|
||||
}
|
||||
if (mode != "p2p" && mode != "reflector" && mode != "group" && mode != "group-churn") {
|
||||
fprintf(stderr, "Error: --mode must be 'p2p', 'reflector', 'group', or 'group-churn'\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Group mode: dispatch to separate implementation
|
||||
if (mode == "group") {
|
||||
return runGroupMode(participants, referenceParticipants, duration, gQuiet, enableVideo, networkScenario);
|
||||
}
|
||||
if (mode == "group-churn") {
|
||||
return runGroupChurnMode(participants, referenceParticipants, duration, gQuiet, enableVideo, churnCycles);
|
||||
}
|
||||
if (mode == "reflector" && reflectorAddr.empty()) {
|
||||
fprintf(stderr, "Error: --reflector host:port is required with --mode reflector\n");
|
||||
return 1;
|
||||
}
|
||||
if (mode == "p2p" && !reflectorAddr.empty()) {
|
||||
fprintf(stderr, "Error: --reflector cannot be used with --mode p2p\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Parse reflector address
|
||||
std::string reflectorHost;
|
||||
uint16_t reflectorPort = 0;
|
||||
if (mode == "reflector") {
|
||||
auto colonPos = reflectorAddr.rfind(':');
|
||||
if (colonPos == std::string::npos) {
|
||||
fprintf(stderr, "Error: --reflector must be in host:port format\n");
|
||||
return 1;
|
||||
}
|
||||
reflectorHost = reflectorAddr.substr(0, colonPos);
|
||||
reflectorPort = static_cast<uint16_t>(std::atoi(reflectorAddr.substr(colonPos + 1).c_str()));
|
||||
if (reflectorPort == 0) {
|
||||
fprintf(stderr, "Error: invalid reflector port\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Generate peer tags for reflector mode
|
||||
std::array<uint8_t, 16> callerPeerTag{};
|
||||
std::array<uint8_t, 16> calleePeerTag{};
|
||||
if (mode == "reflector") {
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
std::uniform_int_distribution<int> dist(0, 255);
|
||||
for (auto& b : callerPeerTag) {
|
||||
b = static_cast<uint8_t>(dist(rng));
|
||||
}
|
||||
calleePeerTag = callerPeerTag;
|
||||
callerPeerTag[0] = 0x00;
|
||||
calleePeerTag[0] = 0x01;
|
||||
}
|
||||
|
||||
// Register implementations
|
||||
tgcalls::Register<tgcalls::InstanceV2Impl>();
|
||||
tgcalls::Register<tgcalls::InstanceV2CompatImpl>();
|
||||
tgcalls::Register<tgcalls::InstanceV2ReferenceImpl>();
|
||||
|
||||
// Create shared encryption key
|
||||
auto keyData = std::make_shared<std::array<uint8_t, 256>>();
|
||||
{
|
||||
std::mt19937 rng(42);
|
||||
std::uniform_int_distribution<int> dist(0, 255);
|
||||
for (auto& b : *keyData) {
|
||||
b = static_cast<uint8_t>(dist(rng));
|
||||
}
|
||||
}
|
||||
|
||||
// Bridge and state
|
||||
auto bridge = std::make_shared<SignalingBridge>();
|
||||
bridge->dropRate = dropRate;
|
||||
bridge->delayMinMs = delayMinMs;
|
||||
bridge->delayMaxMs = delayMaxMs;
|
||||
auto callState = std::make_shared<CallState>();
|
||||
|
||||
// Audio components
|
||||
auto callerRecorder = std::make_shared<SineRecorder>();
|
||||
auto callerRenderer = std::make_shared<NoOpRenderer>();
|
||||
auto calleeRecorder = std::make_shared<SineRecorder>();
|
||||
auto calleeRenderer = std::make_shared<NoOpRenderer>();
|
||||
|
||||
// Stats log paths (per-process to avoid collisions in parallel runs)
|
||||
std::string callerStatsPath = "/tmp/tgcalls_cli_caller_" + std::to_string(getpid()) + ".json";
|
||||
std::string calleeStatsPath = "/tmp/tgcalls_cli_callee_" + std::to_string(getpid()) + ".json";
|
||||
|
||||
// --- Caller descriptor ---
|
||||
auto callerDesc = (tgcalls::Descriptor){
|
||||
.version = version,
|
||||
.config = {
|
||||
.initializationTimeout = 10.0,
|
||||
.receiveTimeout = 10.0,
|
||||
.enableP2P = (mode == "p2p"),
|
||||
.statsLogPath = {callerStatsPath},
|
||||
},
|
||||
.rtcServers = (mode == "reflector")
|
||||
? std::vector<tgcalls::RtcServer>{makeReflectorServer(reflectorHost, reflectorPort, callerPeerTag)}
|
||||
: std::vector<tgcalls::RtcServer>{},
|
||||
.encryptionKey = tgcalls::EncryptionKey(keyData, true),
|
||||
.stateUpdated = [callState](tgcalls::State state) {
|
||||
logMsg("Caller", "state -> %s", stateName(state));
|
||||
std::lock_guard<std::mutex> lock(callState->mutex);
|
||||
callState->callerState = state;
|
||||
if (state == tgcalls::State::Established && callState->establishedAt < 0) {
|
||||
callState->establishedAt = elapsed();
|
||||
}
|
||||
if (state == tgcalls::State::Failed) {
|
||||
callState->errors.push_back("Caller entered Failed state");
|
||||
}
|
||||
},
|
||||
.signalingDataEmitted = [bridge](const std::vector<uint8_t>& data) {
|
||||
logMsg("Caller", "signaling data emitted (%zu bytes)", data.size());
|
||||
std::lock_guard<std::mutex> lock(bridge->mutex);
|
||||
bridge->deliver("Caller", data, bridge->callee);
|
||||
},
|
||||
.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
|
||||
callerRenderer, callerRecorder,
|
||||
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
|
||||
),
|
||||
};
|
||||
|
||||
// --- Callee descriptor ---
|
||||
auto calleeDesc = (tgcalls::Descriptor){
|
||||
.version = version2,
|
||||
.config = {
|
||||
.initializationTimeout = 10.0,
|
||||
.receiveTimeout = 10.0,
|
||||
.enableP2P = (mode == "p2p"),
|
||||
.statsLogPath = {calleeStatsPath},
|
||||
},
|
||||
.rtcServers = (mode == "reflector")
|
||||
? std::vector<tgcalls::RtcServer>{makeReflectorServer(reflectorHost, reflectorPort, calleePeerTag)}
|
||||
: std::vector<tgcalls::RtcServer>{},
|
||||
.encryptionKey = tgcalls::EncryptionKey(keyData, false),
|
||||
.stateUpdated = [callState](tgcalls::State state) {
|
||||
logMsg("Callee", "state -> %s", stateName(state));
|
||||
std::lock_guard<std::mutex> lock(callState->mutex);
|
||||
callState->calleeState = state;
|
||||
if (state == tgcalls::State::Established && callState->establishedAt < 0) {
|
||||
callState->establishedAt = elapsed();
|
||||
}
|
||||
if (state == tgcalls::State::Failed) {
|
||||
callState->errors.push_back("Callee entered Failed state");
|
||||
}
|
||||
},
|
||||
.signalingDataEmitted = [bridge](const std::vector<uint8_t>& data) {
|
||||
logMsg("Callee", "signaling data emitted (%zu bytes)", data.size());
|
||||
std::lock_guard<std::mutex> lock(bridge->mutex);
|
||||
bridge->deliver("Callee", data, bridge->caller);
|
||||
},
|
||||
.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
|
||||
calleeRenderer, calleeRecorder,
|
||||
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
|
||||
),
|
||||
};
|
||||
|
||||
// Create instances
|
||||
auto callerInstance = std::shared_ptr<tgcalls::Instance>(
|
||||
tgcalls::Meta::Create(version, std::move(callerDesc)).release());
|
||||
if (!callerInstance) {
|
||||
fprintf(stderr, "Error: unknown version '%s'\n", version.c_str());
|
||||
return 1;
|
||||
}
|
||||
logMsg("Caller", "created (version %s)", version.c_str());
|
||||
|
||||
auto calleeInstance = std::shared_ptr<tgcalls::Instance>(
|
||||
tgcalls::Meta::Create(version2, std::move(calleeDesc)).release());
|
||||
if (!calleeInstance) {
|
||||
fprintf(stderr, "Error: unknown callee version '%s'\n", version2.c_str());
|
||||
return 1;
|
||||
}
|
||||
logMsg("Callee", "created (version %s)", version2.c_str());
|
||||
|
||||
// Wire bridge
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(bridge->mutex);
|
||||
bridge->caller = callerInstance;
|
||||
bridge->callee = calleeInstance;
|
||||
}
|
||||
|
||||
logMsg("Main", "sleeping for %d seconds...", duration);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(duration));
|
||||
|
||||
// Stop both instances
|
||||
logMsg("Main", "stopping instances...");
|
||||
|
||||
std::atomic<int> stopCount{0};
|
||||
std::mutex stopMutex;
|
||||
std::condition_variable stopCv;
|
||||
|
||||
auto onStopped = [&](const char* role) {
|
||||
return [&, role](tgcalls::FinalState) {
|
||||
logMsg(role, "stopped");
|
||||
stopCount.fetch_add(1);
|
||||
std::lock_guard<std::mutex> lock(stopMutex);
|
||||
stopCv.notify_all();
|
||||
};
|
||||
};
|
||||
|
||||
callerInstance->stop(onStopped("Caller"));
|
||||
calleeInstance->stop(onStopped("Callee"));
|
||||
|
||||
// Wait for both stop callbacks (up to 5 seconds)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(stopMutex);
|
||||
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
|
||||
return stopCount.load() >= 2;
|
||||
});
|
||||
}
|
||||
|
||||
// Release instances — clear bridge first to prevent signaling during teardown
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(bridge->mutex);
|
||||
bridge->caller.reset();
|
||||
bridge->callee.reset();
|
||||
}
|
||||
callerInstance.reset();
|
||||
calleeInstance.reset();
|
||||
|
||||
// Read stats logs: count bitrate records and check for non-zero BWE
|
||||
struct StatsResult {
|
||||
int bitrateRecords = 0;
|
||||
bool hasNonZeroBwe = false;
|
||||
};
|
||||
auto parseStatsLog = [](const std::string& path) -> StatsResult {
|
||||
StatsResult result;
|
||||
std::ifstream f(path);
|
||||
if (!f.is_open()) return result;
|
||||
std::string content((std::istreambuf_iterator<char>(f)),
|
||||
std::istreambuf_iterator<char>());
|
||||
size_t pos = 0;
|
||||
while ((pos = content.find("\"b\":", pos)) != std::string::npos) {
|
||||
pos += 4;
|
||||
result.bitrateRecords++;
|
||||
// Parse the integer value after "b":
|
||||
int val = std::atoi(content.c_str() + pos);
|
||||
if (val > 0) {
|
||||
result.hasNonZeroBwe = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
auto callerStats = parseStatsLog(callerStatsPath);
|
||||
auto calleeStats = parseStatsLog(calleeStatsPath);
|
||||
unlink(callerStatsPath.c_str());
|
||||
unlink(calleeStatsPath.c_str());
|
||||
|
||||
// Print summary
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(callState->mutex);
|
||||
|
||||
bool established = (callState->establishedAt >= 0);
|
||||
|
||||
printf("\n=== Call Summary ===\n");
|
||||
printf("Duration: %ds\n", duration);
|
||||
if (dropRate > 0.0 || delayMaxMs > 0) {
|
||||
printf("Signaling: drop=%.0f%% delay=%d-%dms\n",
|
||||
dropRate * 100.0, delayMinMs, delayMaxMs);
|
||||
}
|
||||
if (mode == "reflector") {
|
||||
printf("Mode: reflector (%s:%d)\n", reflectorHost.c_str(), reflectorPort);
|
||||
} else {
|
||||
printf("Mode: p2p\n");
|
||||
}
|
||||
printf("Caller state: %s\n", stateName(callState->callerState));
|
||||
printf("Callee state: %s\n", stateName(callState->calleeState));
|
||||
if (callState->establishedAt >= 0) {
|
||||
printf("Call established: yes (at %.3fs)\n", callState->establishedAt);
|
||||
} else {
|
||||
printf("Call established: no\n");
|
||||
}
|
||||
bool bweNonZero = callerStats.hasNonZeroBwe && calleeStats.hasNonZeroBwe;
|
||||
|
||||
printf("Stats log: caller=%d callee=%d bitrate records\n",
|
||||
callerStats.bitrateRecords, calleeStats.bitrateRecords);
|
||||
printf("BWE non-zero: %s\n", bweNonZero ? "yes" : "no");
|
||||
|
||||
bool statsCollected = (callerStats.bitrateRecords > 0 && calleeStats.bitrateRecords > 0);
|
||||
|
||||
if (callState->errors.empty()) {
|
||||
printf("Errors: none\n");
|
||||
} else {
|
||||
printf("Errors:\n");
|
||||
for (const auto& err : callState->errors) {
|
||||
printf(" - %s\n", err.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Use _exit() to skip static destruction. ThreadLocalObject's destructor
|
||||
// posts fire-and-forget cleanup tasks to the tgcalls media thread. If we
|
||||
// return normally, static destruction tears down the StaticThreads thread
|
||||
// pool while those tasks may still be executing, causing "pure virtual
|
||||
// function called" when a half-destroyed object's vtable is accessed.
|
||||
fflush(stdout);
|
||||
fflush(stderr);
|
||||
_exit(established && statsCollected && bweNonZero ? 0 : 1);
|
||||
}
|
||||
}
|
||||
Executable
+105
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env bash
|
||||
# Run N parallel P2P tests locally and report aggregate results.
|
||||
#
|
||||
# Usage:
|
||||
# ./run-local-test.sh # 100 calls, 15s each, 30% loss
|
||||
# ./run-local-test.sh -n 1000 # 1000 calls
|
||||
# ./run-local-test.sh -n 500 -j 200 # 500 calls, 200 parallel
|
||||
# ./run-local-test.sh -n 100 -d 30 # 100 calls, 30s each
|
||||
# ./run-local-test.sh --drop-rate 0.5 # 50% loss
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BINARY="./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli"
|
||||
NUM=100
|
||||
PARALLEL=150
|
||||
DURATION=15
|
||||
DROP_RATE=0.3
|
||||
DELAY="50-200"
|
||||
MODE="p2p"
|
||||
VERSION="13.0.0"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-n) NUM="$2"; shift 2 ;;
|
||||
-j) PARALLEL="$2"; shift 2 ;;
|
||||
-d) DURATION="$2"; shift 2 ;;
|
||||
--drop-rate) DROP_RATE="$2"; shift 2 ;;
|
||||
--delay) DELAY="$2"; shift 2 ;;
|
||||
--mode) MODE="$2"; shift 2 ;;
|
||||
--version) VERSION="$2"; shift 2 ;;
|
||||
*) echo "Usage: $0 [-n NUM] [-j PARALLEL] [-d DURATION] [--drop-rate RATE] [--delay MIN-MAX] [--mode MODE] [--version VER]"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ ! -x "$BINARY" ]; then
|
||||
echo "Binary not found: $BINARY"
|
||||
echo "Run: ./build-input/bazel-8.4.2 build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap "rm -rf $TMPDIR" EXIT
|
||||
|
||||
echo "Running $NUM calls ($PARALLEL parallel, ${DURATION}s each, drop=${DROP_RATE}, delay=${DELAY}ms, mode=${MODE}, version=${VERSION})"
|
||||
|
||||
START=$(date +%s)
|
||||
launched=0
|
||||
wave=0
|
||||
|
||||
while [ $launched -lt $NUM ]; do
|
||||
wave=$((wave + 1))
|
||||
remaining=$((NUM - launched))
|
||||
batch=$((remaining > PARALLEL ? PARALLEL : remaining))
|
||||
|
||||
pids=()
|
||||
for i in $(seq 1 $batch); do
|
||||
id=$((launched + i))
|
||||
(
|
||||
if "$BINARY" --mode "$MODE" --duration "$DURATION" \
|
||||
--drop-rate "$DROP_RATE" --delay "$DELAY" --version "$VERSION" --quiet \
|
||||
> /dev/null 2>&1; then
|
||||
echo "pass" > "$TMPDIR/$id"
|
||||
else
|
||||
echo "fail" > "$TMPDIR/$id"
|
||||
fi
|
||||
) &
|
||||
pids+=($!)
|
||||
done
|
||||
|
||||
for pid in "${pids[@]}"; do
|
||||
wait "$pid" 2>/dev/null || true
|
||||
done
|
||||
|
||||
launched=$((launched + batch))
|
||||
echo " Wave $wave: $launched/$NUM done"
|
||||
done
|
||||
|
||||
END=$(date +%s)
|
||||
ELAPSED=$((END - START))
|
||||
|
||||
# Tally
|
||||
success=0
|
||||
failed=0
|
||||
for f in "$TMPDIR"/*; do
|
||||
[ -f "$f" ] || continue
|
||||
if [ "$(cat "$f")" = "pass" ]; then
|
||||
success=$((success + 1))
|
||||
else
|
||||
failed=$((failed + 1))
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=== Local Mass Test Results ==="
|
||||
echo "Total: $NUM"
|
||||
echo "Success: $success"
|
||||
echo "Failed: $failed"
|
||||
if [ $NUM -gt 0 ]; then
|
||||
rate=$(echo "scale=1; $success * 100 / $NUM" | bc)
|
||||
echo "Rate: ${rate}%"
|
||||
fi
|
||||
echo "Duration: ${ELAPSED}s"
|
||||
echo "Parallel: $PARALLEL"
|
||||
|
||||
exit 0
|
||||
Executable
+249
@@ -0,0 +1,249 @@
|
||||
#!/usr/bin/env bash
|
||||
# Launch N tgcalls test tasks on ECS Fargate, spread across reflectors.
|
||||
#
|
||||
# Usage:
|
||||
# ./run-test.sh # 10 tasks, 30s each
|
||||
# ./run-test.sh -n 100 # 100 tasks
|
||||
# ./run-test.sh -n 50 -d 60 # 50 tasks, 60s each
|
||||
# ./run-test.sh --results # fetch results from last run
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
CLUSTER="tgcalls-test"
|
||||
TASK_DEF="tgcalls-test"
|
||||
REGION="eu-west-1"
|
||||
SUBNETS="subnet-0292f49f3b4885428,subnet-09b8edab6eb20b837,subnet-0f464b5c62c9a6d1a"
|
||||
SECURITY_GROUP="sg-0d87a1f19be76c160"
|
||||
LOG_GROUP="/ecs/tgcalls-test"
|
||||
REFLECTOR_URL="https://core.telegram.org/getReflectorList"
|
||||
RUN_FILE="/tmp/tgcalls-last-run.txt"
|
||||
STATUS_FILE="/tmp/tgcalls-last-status.txt"
|
||||
|
||||
NUM_TASKS=10
|
||||
DURATION=30
|
||||
|
||||
usage() {
|
||||
echo "Usage: $0 [-n NUM_TASKS] [-d DURATION_SECS] [--results]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
fetch_results() {
|
||||
if [ ! -f "$RUN_FILE" ]; then
|
||||
echo "No run file found. Run a test first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Fetching results from last run..."
|
||||
echo ""
|
||||
|
||||
TMPDIR_RESULTS=$(mktemp -d)
|
||||
RESULTS_PARALLEL=20
|
||||
total=$(wc -l < "$RUN_FILE" | tr -d ' ')
|
||||
fetched=0
|
||||
|
||||
# Fetch logs in parallel
|
||||
while IFS= read -r task_id; do
|
||||
stream="tgcalls/tgcalls/${task_id}"
|
||||
(aws logs get-log-events \
|
||||
--log-group-name "$LOG_GROUP" \
|
||||
--log-stream-name "$stream" \
|
||||
--region "$REGION" \
|
||||
--query 'events[*].message' \
|
||||
--output text > "${TMPDIR_RESULTS}/${task_id}" 2>/dev/null || true) &
|
||||
|
||||
fetched=$((fetched + 1))
|
||||
# Throttle: wait every RESULTS_PARALLEL calls
|
||||
if [ $((fetched % RESULTS_PARALLEL)) -eq 0 ]; then
|
||||
wait
|
||||
echo -ne " Fetched $fetched/$total\r"
|
||||
fi
|
||||
done < "$RUN_FILE"
|
||||
wait
|
||||
echo " Fetched $total/$total"
|
||||
echo ""
|
||||
|
||||
# Tally results
|
||||
success=0
|
||||
fail=0
|
||||
errors=""
|
||||
no_logs_tasks=()
|
||||
|
||||
for result_file in "${TMPDIR_RESULTS}"/*; do
|
||||
[ -f "$result_file" ] || continue
|
||||
task_id=$(basename "$result_file")
|
||||
output=$(cat "$result_file")
|
||||
|
||||
if [ -z "$output" ]; then
|
||||
# No logs yet — queue for retry
|
||||
no_logs_tasks+=("$task_id")
|
||||
elif echo "$output" | tr '\t' '\n' | grep -q "Audio received:.*yes" && echo "$output" | tr '\t' '\n' | grep -q "Call established:.*yes"; then
|
||||
success=$((success + 1))
|
||||
else
|
||||
fail=$((fail + 1))
|
||||
reflector=$(echo "$output" | tr '\t' '\n' | grep -o 'reflector ([^)]*' | sed 's/reflector (//' || echo "unknown")
|
||||
errors="${errors}\n ${task_id}: reflector=${reflector}"
|
||||
fi
|
||||
done
|
||||
|
||||
rm -rf "$TMPDIR_RESULTS"
|
||||
|
||||
# Retry tasks that had no logs
|
||||
if [ ${#no_logs_tasks[@]} -gt 0 ]; then
|
||||
echo "Retrying ${#no_logs_tasks[@]} tasks with missing logs..."
|
||||
sleep 5
|
||||
for task_id in "${no_logs_tasks[@]}"; do
|
||||
stream="tgcalls/tgcalls/${task_id}"
|
||||
output=$(aws logs get-log-events \
|
||||
--log-group-name "$LOG_GROUP" \
|
||||
--log-stream-name "$stream" \
|
||||
--region "$REGION" \
|
||||
--query 'events[*].message' \
|
||||
--output text 2>/dev/null || true)
|
||||
|
||||
if [ -n "$output" ] && echo "$output" | tr '\t' '\n' | grep -q "Audio received:.*yes" && echo "$output" | tr '\t' '\n' | grep -q "Call established:.*yes"; then
|
||||
success=$((success + 1))
|
||||
else
|
||||
fail=$((fail + 1))
|
||||
ecs_info=""
|
||||
if [ -f "$STATUS_FILE" ]; then
|
||||
ecs_info=$(grep "^${task_id}" "$STATUS_FILE" | head -1 | cut -f2-3)
|
||||
fi
|
||||
if [ -n "$ecs_info" ]; then
|
||||
errors="${errors}\n ${task_id}: exit=${ecs_info}"
|
||||
else
|
||||
errors="${errors}\n ${task_id} (no logs, no ECS status)"
|
||||
fi
|
||||
fi
|
||||
done
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo "=== Test Results ==="
|
||||
echo "Total tasks: $total"
|
||||
echo "Success: $success"
|
||||
echo "Failed: $fail"
|
||||
if [ -n "$errors" ]; then
|
||||
echo -e "\nFailed tasks:${errors}"
|
||||
fi
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Parse args
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-n) NUM_TASKS="$2"; shift 2 ;;
|
||||
-d) DURATION="$2"; shift 2 ;;
|
||||
--results) fetch_results ;;
|
||||
*) usage ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Fetch reflector list — IPs only (port randomized by CLI)
|
||||
echo "Fetching reflector list..."
|
||||
REFLECTOR_CSV=$(curl -s "$REFLECTOR_URL" | cut -d: -f1 | sort -u | tr '\n' ',' | sed 's/,$//')
|
||||
NUM_REFLECTORS=$(echo "$REFLECTOR_CSV" | tr ',' '\n' | wc -l | tr -d ' ')
|
||||
echo "Got $NUM_REFLECTORS unique reflector IPs"
|
||||
|
||||
# To inject bad addresses for testing, uncomment:
|
||||
# NUM_BAD=$(( NUM_REFLECTORS / 9 ))
|
||||
# BAD_CSV=$(for i in $(seq 1 $NUM_BAD); do echo -n "10.255.255.$((i % 256)):1,"; done | sed 's/,$//')
|
||||
# REFLECTOR_CSV="${REFLECTOR_CSV},${BAD_CSV}"
|
||||
# echo "Injected $NUM_BAD bad addresses (~10% of pool)"
|
||||
|
||||
echo "Launching $NUM_TASKS tasks (${DURATION}s each), each picks a random reflector..."
|
||||
echo ""
|
||||
|
||||
# Clear run files
|
||||
> "$RUN_FILE"
|
||||
> "$STATUS_FILE"
|
||||
|
||||
# Launch in waves of WAVE_SIZE, waiting for each wave to complete before the next.
|
||||
# Within each wave, fire PARALLEL API calls concurrently (each launching up to 10 tasks).
|
||||
WAVE_SIZE=500
|
||||
PARALLEL=10
|
||||
TMPDIR_LAUNCH=$(mktemp -d)
|
||||
remaining=$NUM_TASKS
|
||||
wave=0
|
||||
|
||||
while [ $remaining -gt 0 ]; do
|
||||
wave=$((wave + 1))
|
||||
wave_target=$((remaining > WAVE_SIZE ? WAVE_SIZE : remaining))
|
||||
wave_arns=()
|
||||
wave_launched=0
|
||||
|
||||
echo "=== Wave $wave: launching $wave_target tasks ==="
|
||||
|
||||
while [ $wave_launched -lt $wave_target ]; do
|
||||
pids=()
|
||||
api_calls=0
|
||||
for p in $(seq 1 $PARALLEL); do
|
||||
left=$((wave_target - wave_launched - api_calls * 10))
|
||||
[ $left -le 0 ] && break
|
||||
batch=$((left > 10 ? 10 : left))
|
||||
outfile="${TMPDIR_LAUNCH}/batch_${wave}_${wave_launched}_${p}"
|
||||
api_calls=$((api_calls + 1))
|
||||
|
||||
(aws ecs run-task --region "$REGION" \
|
||||
--cluster "$CLUSTER" \
|
||||
--task-definition "$TASK_DEF" \
|
||||
--launch-type FARGATE \
|
||||
--count "$batch" \
|
||||
--network-configuration "awsvpcConfiguration={subnets=[${SUBNETS}],securityGroups=[${SECURITY_GROUP}],assignPublicIp=ENABLED}" \
|
||||
--overrides "{\"containerOverrides\":[{\"name\":\"tgcalls\",\"command\":[\"--quiet\",\"--reflector-list\",\"${REFLECTOR_CSV}\",\"--duration\",\"${DURATION}\",\"--drop-rate\",\"0.3\",\"--delay\",\"50-200\"]}]}" \
|
||||
--query 'tasks[*].taskArn' --output text > "$outfile" 2>&1) &
|
||||
pids+=($!)
|
||||
done
|
||||
|
||||
for pid in "${pids[@]}"; do
|
||||
wait "$pid" 2>/dev/null || true
|
||||
done
|
||||
|
||||
for outfile in "${TMPDIR_LAUNCH}"/batch_*; do
|
||||
[ -f "$outfile" ] || continue
|
||||
while read -r arn; do
|
||||
if [[ "$arn" == arn:* ]]; then
|
||||
task_id="${arn##*/}"
|
||||
wave_arns+=("$arn")
|
||||
echo "$task_id" >> "$RUN_FILE"
|
||||
wave_launched=$((wave_launched + 1))
|
||||
fi
|
||||
done < <(tr '\t' '\n' < "$outfile")
|
||||
rm -f "$outfile"
|
||||
done
|
||||
|
||||
echo " Launched $wave_launched/$wave_target in wave $wave"
|
||||
done
|
||||
|
||||
remaining=$((remaining - wave_launched))
|
||||
echo " Waiting for wave $wave ($wave_launched tasks) to finish..."
|
||||
|
||||
# Wait in batches of 100
|
||||
for ((start=0; start<${#wave_arns[@]}; start+=100)); do
|
||||
batch=("${wave_arns[@]:$start:100}")
|
||||
aws ecs wait tasks-stopped \
|
||||
--cluster "$CLUSTER" \
|
||||
--tasks "${batch[@]}" \
|
||||
--region "$REGION" 2>/dev/null || true
|
||||
done
|
||||
|
||||
# Collect ECS task status while data is fresh (expires after ~1hr)
|
||||
echo " Collecting task status for wave $wave..."
|
||||
for ((start=0; start<${#wave_arns[@]}; start+=100)); do
|
||||
batch=("${wave_arns[@]:$start:100}")
|
||||
aws ecs describe-tasks --cluster "$CLUSTER" --tasks "${batch[@]}" --region "$REGION" \
|
||||
--query 'tasks[*].[containers[0].taskArn,containers[0].exitCode,stoppedReason]' \
|
||||
--output text 2>/dev/null | while IFS=$'\t' read -r arn exit_code reason; do
|
||||
task_id="${arn##*/}"
|
||||
echo -e "${task_id}\t${exit_code}\t${reason}" >> "$STATUS_FILE"
|
||||
done
|
||||
done
|
||||
|
||||
echo " Wave $wave complete."
|
||||
echo ""
|
||||
done
|
||||
|
||||
rm -rf "$TMPDIR_LAUNCH"
|
||||
total_launched=$(wc -l < "$RUN_FILE" | tr -d ' ')
|
||||
|
||||
echo "Launched $total_launched/$NUM_TASKS total tasks."
|
||||
echo "Run '$0 --results' to see results."
|
||||
Executable
+93
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env bash
|
||||
# Run parallel tests, stop on first crash (non-zero exit).
|
||||
set -euo pipefail
|
||||
|
||||
BINARY="./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli"
|
||||
PARALLEL=250
|
||||
DURATION=15
|
||||
VERSION="11.0.0"
|
||||
DROP_RATE=0.3
|
||||
DELAY="50-200"
|
||||
MODE="p2p"
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-j) PARALLEL="$2"; shift 2 ;;
|
||||
-d) DURATION="$2"; shift 2 ;;
|
||||
--version) VERSION="$2"; shift 2 ;;
|
||||
--drop-rate) DROP_RATE="$2"; shift 2 ;;
|
||||
--delay) DELAY="$2"; shift 2 ;;
|
||||
--mode) MODE="$2"; shift 2 ;;
|
||||
*) echo "Usage: $0 [-j PARALLEL] [-d DURATION] [--version VER] [--drop-rate R] [--delay D] [--mode M]"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ ! -x "$BINARY" ]; then
|
||||
echo "Binary not found: $BINARY"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap "rm -rf $TMPDIR" EXIT
|
||||
|
||||
echo "Running waves of $PARALLEL until first crash (${DURATION}s, drop=${DROP_RATE}, delay=${DELAY}, version=${VERSION})"
|
||||
|
||||
wave=0
|
||||
total=0
|
||||
while true; do
|
||||
wave=$((wave + 1))
|
||||
pids=()
|
||||
for i in $(seq 1 $PARALLEL); do
|
||||
id=$((total + i))
|
||||
(
|
||||
set +e
|
||||
"$BINARY" --mode "$MODE" --duration "$DURATION" \
|
||||
--drop-rate "$DROP_RATE" --delay "$DELAY" --version "$VERSION" --quiet \
|
||||
> "$TMPDIR/${id}.out" 2>"$TMPDIR/${id}.err"
|
||||
echo $? > "$TMPDIR/${id}.rc"
|
||||
) &
|
||||
pids+=($!)
|
||||
done
|
||||
|
||||
for pid in "${pids[@]}"; do
|
||||
wait "$pid" 2>/dev/null || true
|
||||
done
|
||||
|
||||
total=$((total + PARALLEL))
|
||||
|
||||
# Check for crashes
|
||||
crashes=0
|
||||
for i in $(seq $((total - PARALLEL + 1)) $total); do
|
||||
rc_file="$TMPDIR/${i}.rc"
|
||||
if [ ! -f "$rc_file" ]; then
|
||||
crashes=$((crashes + 1))
|
||||
echo ""
|
||||
echo "=== CRASH in run $i (no rc file) ==="
|
||||
echo "--- stderr ---"
|
||||
cat "$TMPDIR/${i}.err" 2>/dev/null || echo "(empty)"
|
||||
else
|
||||
rc=$(cat "$rc_file")
|
||||
if [ "$rc" -gt 128 ] 2>/dev/null; then
|
||||
crashes=$((crashes + 1))
|
||||
echo ""
|
||||
echo "=== CRASH in run $i (exit $rc) ==="
|
||||
echo "--- stderr ---"
|
||||
cat "$TMPDIR/${i}.err" 2>/dev/null || echo "(empty)"
|
||||
echo "--- stdout ---"
|
||||
cat "$TMPDIR/${i}.out" 2>/dev/null || echo "(empty)"
|
||||
# Only show first crash in detail
|
||||
if [ $crashes -eq 1 ]; then
|
||||
echo "=== END CRASH ==="
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $crashes -gt 0 ]; then
|
||||
echo ""
|
||||
echo "Wave $wave: $crashes crashes in $PARALLEL runs (total $total runs)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo " Wave $wave: $PARALLEL/$PARALLEL passed (total $total)"
|
||||
done
|
||||
@@ -0,0 +1,18 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
|
||||
|
||||
go_binary(
|
||||
name = "go_sfu",
|
||||
srcs = glob(["*.go"]),
|
||||
cgo = True,
|
||||
linkmode = "c-archive",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_github_pion_datachannel//:datachannel",
|
||||
"@com_github_pion_dtls_v3//:dtls",
|
||||
"@com_github_pion_ice_v4//:ice",
|
||||
"@com_github_pion_logging//:logging",
|
||||
"@com_github_pion_rtcp//:rtcp",
|
||||
"@com_github_pion_sctp//:sctp",
|
||||
"@com_github_pion_srtp_v3//:srtp",
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
# Go/Pion SFU
|
||||
|
||||
The group call test mode uses a Go-based SFU (Selective Forwarding Unit) built with [Pion WebRTC](https://github.com/pion/webrtc), linked into the C++ `tgcalls_cli` binary via CGo.
|
||||
|
||||
## Build Integration
|
||||
- `MODULE.bazel` — `rules_go` 0.60.0 + Go SDK 1.24.2 + `gazelle` 0.43.0; Pion dependencies managed via `go_deps` Gazelle extension
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/BUILD` — `go_binary` with `linkmode = "c-archive"` produces a static archive + CGo header, exposes `CcInfo` to C++ targets
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go.mod` / `go.sum` — Pion dependency declarations (pion/ice, pion/dtls, pion/srtp, pion/sctp)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/BUILD` — depends on `//submodules/TgVoipWebrtc/tgcalls/tools/go_sfu` to link the Go archive
|
||||
- The CGo-generated header is included as `#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"` in C++ code
|
||||
|
||||
## How It Works
|
||||
The `go_binary` with `linkmode = "c-archive"` compiles Go code (including the Go runtime) into a `.a` static archive. Functions annotated with `//export` in Go become C-callable symbols. Bazel's `rules_go` automatically provides `CcInfo`, so `cc_binary` targets can depend on the Go archive via `deps` — no manual linkopts needed.
|
||||
|
||||
The Go runtime (GC, goroutine scheduler) runs inside the C++ process. This adds ~10MB memory overhead. `GoSfu_Init()` must be called before any other Go functions.
|
||||
|
||||
## Key Files
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/sfu.go` — SFU core: participant registry, join/leave/response flow, audio+video RTP forwarding, SSRC registry (audio/video/video-rtx with layer index), Colibri `ReceiverVideoConstraints`/`SenderVideoConstraints` handling, PLI/FIR forwarding, `ActiveAudioSsrcs`/`ActiveVideoSsrcs` broadcasting, `//export` C bindings (`GoSfu_Init`, `GoSfu_Create`, `GoSfu_Destroy`, `GoSfu_Join`, `GoSfu_Leave`, `GoSfu_QuerySsrc`, `GoSfu_QueryVideoSsrcs`, `GoSfu_Free`, `GoSfu_Shutdown`)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go` — per-participant transport stack (ICE agent, DTLS conn, SRTP session, SRTCP contexts for manual RTCP decrypt/encrypt, SCTP association, data channel send/receive, per-receiver video layer selection)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/mux.go` — packet demuxer: three-way split of ICE traffic into DTLS handshake, SRTP (RTP), and SRTCP (RTCP) channels per RFC 7983 + RFC 5761
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go.mod` / `go.sum` — Go module with Pion dependencies
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.cpp` — C++ side that drives the group join flow and calls into Go SFU
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.h` — header for group mode entry point
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_participant.h/.cpp` — shared participant lifecycle helpers (`createParticipant`, `stopParticipant`, `validateGroupState`, `printGroupSummary`), `ParticipantState` struct, audio helpers
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_churn_mode.h/.cpp` — group-churn stress test: base group + rapid join/leave cycling
|
||||
|
||||
## SFU Bandwidth Adaptation
|
||||
|
||||
The SFU implements REMB-based bandwidth-adaptive simulcast layer selection for video. Per receiver, it maintains an EWMA-smoothed bandwidth estimate from REMB RTCP feedback and uses a `LayerSelector` state machine per (receiver, sender) pair to decide which simulcast layer to forward.
|
||||
|
||||
### State Machine
|
||||
- **STABLE**: forwarding current layer. Checks for upswitch opportunity (REMB > threshold × 1.2) or downswitch need (REMB < threshold × 0.7).
|
||||
- **PROBING_UP**: ramping RTX padding from 0 to the gap between current and target layer bitrate over 2 seconds. Aborts if REMB drops; succeeds if REMB sustains.
|
||||
- **GRACE_DOWN**: REMB below downswitch threshold. Waits 500ms, then downswitches if not recovered. 5-second cooldown after any switch.
|
||||
|
||||
### Layer Thresholds
|
||||
| Layer | Nominal | Upswitch When | Downswitch When |
|
||||
|-------|---------|--------------|-----------------|
|
||||
| 0 | 60 kbps | (start) | (never) |
|
||||
| 1 | 110 kbps | BW > 132 kbps | BW < 77 kbps |
|
||||
| 2 | 900 kbps | BW > 1,080 kbps | BW < 630 kbps |
|
||||
|
||||
### Layer Selection and SSRC Rewriting
|
||||
The SFU forwards exactly one simulcast layer per (receiver, sender) pair. Before `ReceiverVideoConstraints` arrives, the SFU uses `requestedLayer` as the cap and forwards at `maxActiveLayer` (the highest layer the encoder actually produces). After constraints arrive, `ensureLayerSelector` sets `selectedLayer` clamped to `maxActiveLayer`.
|
||||
|
||||
When forwarding a non-base layer, the SFU rewrites the RTP SSRC to the primary (layer 0) SSRC. This is necessary because `IncomingVideoChannel` in CustomImpl attaches its `VideoSinkImpl` to `_mainVideoSsrc` (the first SSRC in the SIM group, i.e., layer 0). Without SSRC rewriting, packets from higher layers are delivered to the wrong receive stream and never decoded. RTX SSRCs are similarly rewritten to the layer 0 FID SSRC.
|
||||
|
||||
### Testing on Localhost
|
||||
Use `--network-scenario step-down-up` to exercise the full adaptation path via per-client network simulation (replaces the old REMB-override `--bw-scenario`).
|
||||
|
||||
```bash
|
||||
# Network scenario test (30s, 4 phases: uncapped → 80k egress → 200k → uncapped)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 30 --network-scenario step-down-up
|
||||
```
|
||||
|
||||
Unit tests drive the `LayerSelector` state machine directly via mocked callbacks. Run from `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/`:
|
||||
|
||||
```bash
|
||||
go test -run TestLayerSelector -v -timeout 60s
|
||||
```
|
||||
|
||||
Covers upswitch L0→L1→L2, downswitch L2→L1→L0, grace-down recovery on transient dips, stale-BW idle behavior, the `OnMaxActiveLayerIncreased` fallback used when clients don't send REMB, and `maxLayer` enforcement.
|
||||
|
||||
### REMB-free Fallback
|
||||
|
||||
Real tgcalls clients negotiate `goog-remb` but use transport-cc as the primary BWE signal, so no REMB actually arrives at the SFU. This means the REMB-driven state machine never enters `PROBING_UP` in live runs. `LayerSelector.OnMaxActiveLayerIncreased(maxActive)` is the fallback: when the sender starts producing a higher simulcast layer than previously seen AND the BW estimate is stale, the SFU immediately upshifts to the highest available layer (clamped by `maxLayer`). Called from `sfu.go`'s packet-forwarding path whenever `maxActiveLayer[senderID]` is bumped.
|
||||
|
||||
### Key Files
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/bandwidth.go` — `BandwidthEstimator`, `LayerSelector`, `RtxRingBuffer`, `OnMaxActiveLayerIncreased`
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/bandwidth_test.go` — unit tests for `LayerSelector` up/down transitions
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go` — REMB parsing in `readRTCPLoop()`, `selectedLayers`
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/sfu.go` — layer-filtered forwarding, `ensureLayerSelector`, SSRC rewriting, `maxActiveLayer` tracking
|
||||
|
||||
## SFU Transport-CC Feedback
|
||||
|
||||
The SFU generates RTCP transport-cc feedback (type 205, FMT 15) per sender every 100ms. This provides the sender's GCC (Google Congestion Control) with packet arrival data, enabling BWE ramp-up so the encoder produces higher simulcast layers.
|
||||
|
||||
The feedback reflects actual (or simulated) packet arrivals — if ingress network simulation drops packets, the feedback reports them as missing, causing the sender's GCC to reduce bitrate.
|
||||
|
||||
### How It Works
|
||||
1. Each incoming RTP packet is parsed for the transport-wide sequence number (header extension ID 3, one-byte RFC 5285 format)
|
||||
2. `TransportCCGenerator.RecordArrival(twccSeq)` records the arrival time
|
||||
3. Every 100ms, `emitFeedback()` builds an `rtcp.TransportLayerCC` packet with `PacketChunks` (run-length or status-vector encoding) and `RecvDeltas` (250µs units)
|
||||
4. The feedback is marshalled, encrypted via SRTCP, and sent to the sender
|
||||
5. The sender's `Call::Receiver::DeliverRtcpPacket()` feeds it to the GCC via `GroupNetworkManager::OnRtcpPacketReceived_n` → `_call->Receiver()->DeliverRtcpPacket()`
|
||||
|
||||
### Current Status
|
||||
Transport-cc feedback is working: the SFU records ~60-70 arrivals per first 100ms tick, generates feedback packets (32-128 bytes), and the sender receives them. The GCC ramps from the 400kbps start bitrate to produce layer 1 (640x360). Full ramp to layer 2 (1280x720, needs ~1Mbps) requires further investigation — the GCC may need probing support or the `adjustBitratePreferences` max_bitrate_bps of 1052kbps may be a bottleneck.
|
||||
|
||||
### Key Files
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/twcc.go` — `TransportCCGenerator`, `parseTWCCSeq` (RTP header extension parser)
|
||||
|
||||
## SFU Network Simulation
|
||||
|
||||
Per-client network simulation with independent ingress (from client) and egress (to client) simulators. Each direction has: delay, jitter, packet loss, and bandwidth cap (token bucket).
|
||||
|
||||
```bash
|
||||
# Configure via CGo: GoSfu_SetNetworkParams(handle, participantID, direction, delayMs, jitterMs, dropRate, bandwidthBps)
|
||||
# direction: 0 = ingress, 1 = egress
|
||||
|
||||
# Network scenario test (4 phases: uncapped -> 80k -> 200k -> uncapped)
|
||||
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 30 --network-scenario step-down-up
|
||||
```
|
||||
|
||||
### Key Files
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/network_sim.go` — `NetworkSimulator` (token bucket, delay, jitter, drop)
|
||||
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go` — `ingressSim`, `egressSim` on each `Participant`
|
||||
@@ -0,0 +1,475 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Bandwidth Estimation ---
|
||||
|
||||
const (
|
||||
ewmaAlpha = 0.3
|
||||
safetyFactor = 0.85
|
||||
stalenessTTL = 5 * time.Second
|
||||
)
|
||||
|
||||
// BandwidthEstimator maintains an EWMA-smoothed REMB estimate for a receiver.
|
||||
type BandwidthEstimator struct {
|
||||
mu sync.Mutex
|
||||
lastREMBBps float64
|
||||
smoothedBps float64
|
||||
lastREMBAt time.Time
|
||||
}
|
||||
|
||||
// OnREMB feeds a new REMB value (in bits per second) into the estimator.
|
||||
func (e *BandwidthEstimator) OnREMB(bps float64) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.lastREMBBps = bps
|
||||
e.lastREMBAt = time.Now()
|
||||
if e.smoothedBps == 0 {
|
||||
e.smoothedBps = bps
|
||||
} else {
|
||||
e.smoothedBps = ewmaAlpha*bps + (1-ewmaAlpha)*e.smoothedBps
|
||||
}
|
||||
}
|
||||
|
||||
// EffectiveBps returns the safe bandwidth estimate in bps.
|
||||
// Returns -1 if the estimate is stale (no REMB for stalenessTTL).
|
||||
func (e *BandwidthEstimator) EffectiveBps() float64 {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
if e.lastREMBAt.IsZero() || time.Since(e.lastREMBAt) > stalenessTTL {
|
||||
return -1
|
||||
}
|
||||
return e.smoothedBps * safetyFactor
|
||||
}
|
||||
|
||||
// SmoothedBps returns the raw EWMA value (for logging).
|
||||
func (e *BandwidthEstimator) SmoothedBps() float64 {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.smoothedBps
|
||||
}
|
||||
|
||||
// LastREMBBps returns the last raw REMB value (for logging).
|
||||
func (e *BandwidthEstimator) LastREMBBps() float64 {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.lastREMBBps
|
||||
}
|
||||
|
||||
// --- Layer Bitrate Model ---
|
||||
|
||||
// LayerBitrate holds the thresholds for one simulcast layer.
|
||||
type LayerBitrate struct {
|
||||
Nominal float64 // typical sustained bitrate (bps)
|
||||
UpThresh float64 // effective BW must exceed this to upswitch TO this layer
|
||||
DownThresh float64 // effective BW must drop below this to downswitch FROM this layer
|
||||
}
|
||||
|
||||
// layerBitrates defines the 3 simulcast layers matching tgcalls adjustVideoSendParams().
|
||||
// Layer 0 has no downThresh (always viable) and no upThresh (start here).
|
||||
var layerBitrates = [3]LayerBitrate{
|
||||
{Nominal: 60_000, UpThresh: 0, DownThresh: 0}, // layer 0: 160x90
|
||||
{Nominal: 110_000, UpThresh: 132_000, DownThresh: 77_000}, // layer 1: 320x180
|
||||
{Nominal: 900_000, UpThresh: 1_080_000, DownThresh: 630_000}, // layer 2: 640x360
|
||||
}
|
||||
|
||||
// --- RTX Ring Buffer ---
|
||||
|
||||
// RtxEntry stores one video RTP packet for potential retransmission as RTX padding.
|
||||
type RtxEntry struct {
|
||||
Payload []byte
|
||||
SeqNum uint16
|
||||
Timestamp uint32
|
||||
}
|
||||
|
||||
// RtxRingBuffer is a per-sender circular buffer of recent video RTP packets.
|
||||
type RtxRingBuffer struct {
|
||||
mu sync.Mutex
|
||||
entries []RtxEntry
|
||||
head int
|
||||
count int
|
||||
cap int
|
||||
}
|
||||
|
||||
// NewRtxRingBuffer creates a ring buffer with the given capacity.
|
||||
func NewRtxRingBuffer(capacity int) *RtxRingBuffer {
|
||||
return &RtxRingBuffer{
|
||||
entries: make([]RtxEntry, capacity),
|
||||
cap: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// Push adds a video RTP packet to the ring buffer.
|
||||
// payload is copied so the caller can reuse their buffer.
|
||||
func (r *RtxRingBuffer) Push(payload []byte, seqNum uint16, timestamp uint32) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
entry := &r.entries[r.head]
|
||||
if cap(entry.Payload) >= len(payload) {
|
||||
entry.Payload = entry.Payload[:len(payload)]
|
||||
} else {
|
||||
entry.Payload = make([]byte, len(payload))
|
||||
}
|
||||
copy(entry.Payload, payload)
|
||||
entry.SeqNum = seqNum
|
||||
entry.Timestamp = timestamp
|
||||
r.head = (r.head + 1) % r.cap
|
||||
if r.count < r.cap {
|
||||
r.count++
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns up to n most recent packets (oldest first).
|
||||
func (r *RtxRingBuffer) Get(n int) []RtxEntry {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if n > r.count {
|
||||
n = r.count
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]RtxEntry, n)
|
||||
start := (r.head - r.count + r.cap) % r.cap // oldest entry
|
||||
readFrom := (start + r.count - n + r.cap) % r.cap // start of the n most recent
|
||||
for i := 0; i < n; i++ {
|
||||
idx := (readFrom + i) % r.cap
|
||||
src := &r.entries[idx]
|
||||
entry := RtxEntry{
|
||||
Payload: make([]byte, len(src.Payload)),
|
||||
SeqNum: src.SeqNum,
|
||||
Timestamp: src.Timestamp,
|
||||
}
|
||||
copy(entry.Payload, src.Payload)
|
||||
result[i] = entry
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// rtxEncapsulate wraps an original RTP payload into an RTX packet payload per RFC 4588.
|
||||
// The RTX payload is: [2-byte original sequence number] + [original RTP payload (after header)].
|
||||
// The caller is responsible for setting the RTX SSRC and incrementing RTX sequence number
|
||||
// on the outer RTP header.
|
||||
func rtxEncapsulate(originalPayload []byte, originalSeqNum uint16) []byte {
|
||||
out := make([]byte, 2+len(originalPayload))
|
||||
out[0] = byte(originalSeqNum >> 8)
|
||||
out[1] = byte(originalSeqNum)
|
||||
copy(out[2:], originalPayload)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Layer Selector State Machine ---
|
||||
|
||||
type selectorState int
|
||||
|
||||
const (
|
||||
stateStable selectorState = iota
|
||||
stateProbingUp
|
||||
stateGraceDown
|
||||
)
|
||||
|
||||
func (s selectorState) String() string {
|
||||
switch s {
|
||||
case stateStable:
|
||||
return "STABLE"
|
||||
case stateProbingUp:
|
||||
return "PROBING_UP"
|
||||
case stateGraceDown:
|
||||
return "GRACE_DOWN"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
probeDuration = 2 * time.Second
|
||||
graceDownTimeout = 500 * time.Millisecond
|
||||
cooldownDuration = 5 * time.Second
|
||||
tickInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// LayerSelectorCallbacks provides the hooks the state machine needs into the SFU.
|
||||
type LayerSelectorCallbacks struct {
|
||||
// GetEffectiveBW returns the receiver's current effective bandwidth (bps), or -1 if stale.
|
||||
GetEffectiveBW func() float64
|
||||
// SetSelectedLayer updates the forwarding layer for this (receiver, sender) pair.
|
||||
SetSelectedLayer func(layer int)
|
||||
// SendPLI sends a PLI to the sender for the given SSRC.
|
||||
SendPLI func(ssrc uint32)
|
||||
// GetSenderVideoLayers returns the sender's simulcast layers.
|
||||
GetSenderVideoLayers func() []SimulcastLayer
|
||||
// GetRtxBuffer returns the sender's RTX ring buffer.
|
||||
GetRtxBuffer func() *RtxRingBuffer
|
||||
// SendRtxPadding sends an RTX padding packet to the receiver.
|
||||
// rtxSSRC is the FID SSRC, seqNum is the RTX sequence number.
|
||||
SendRtxPadding func(rtxPayload []byte, rtxSSRC uint32, seqNum uint16, timestamp uint32)
|
||||
// Log emits a log message.
|
||||
Log func(level string, format string, args ...interface{})
|
||||
}
|
||||
|
||||
// LayerSelector manages the state machine for one (receiver, sender) pair.
|
||||
type LayerSelector struct {
|
||||
mu sync.Mutex
|
||||
receiverID int
|
||||
senderID int
|
||||
currentLayer int
|
||||
maxLayer int // max layer the receiver requested
|
||||
state selectorState
|
||||
callbacks LayerSelectorCallbacks
|
||||
|
||||
// Probing state
|
||||
probeTarget int // layer we're probing toward
|
||||
probeStartTime time.Time
|
||||
probeRtxSeq uint16 // incrementing RTX sequence number for padding
|
||||
|
||||
// Grace-down state
|
||||
graceStartTime time.Time
|
||||
|
||||
// Cooldown
|
||||
lastSwitchTime time.Time
|
||||
|
||||
// Control
|
||||
stopCh chan struct{}
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewLayerSelector creates and starts a new LayerSelector.
|
||||
// initialLayer is the layer to start forwarding (typically = requestedLayer).
|
||||
func NewLayerSelector(receiverID, senderID, initialLayer, maxLayer int, cb LayerSelectorCallbacks) *LayerSelector {
|
||||
ls := &LayerSelector{
|
||||
receiverID: receiverID,
|
||||
senderID: senderID,
|
||||
currentLayer: initialLayer,
|
||||
maxLayer: maxLayer,
|
||||
state: stateStable,
|
||||
callbacks: cb,
|
||||
stopCh: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go ls.run()
|
||||
return ls
|
||||
}
|
||||
|
||||
// Stop terminates the selector's tick loop.
|
||||
func (ls *LayerSelector) Stop() {
|
||||
close(ls.stopCh)
|
||||
<-ls.done
|
||||
}
|
||||
|
||||
// SetMaxLayer updates the maximum layer the receiver wants (from ReceiverVideoConstraints).
|
||||
func (ls *LayerSelector) SetMaxLayer(maxLayer int) {
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
ls.maxLayer = maxLayer
|
||||
// If current layer exceeds new max, downswitch immediately.
|
||||
if ls.currentLayer > maxLayer {
|
||||
ls.switchLayer(maxLayer)
|
||||
}
|
||||
}
|
||||
|
||||
// OnMaxActiveLayerIncreased is called when the sender starts producing a
|
||||
// higher simulcast layer than previously observed. If the BW estimate is
|
||||
// stale (no REMB arriving — common when clients use transport-cc exclusively
|
||||
// and the SFU hasn't generated REMB), upshift immediately up to maxLayer so
|
||||
// the receiver gets the best available layer. When REMB is fresh, the state
|
||||
// machine is in charge and this is a no-op.
|
||||
func (ls *LayerSelector) OnMaxActiveLayerIncreased(maxActive int) {
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
if ls.callbacks.GetEffectiveBW() >= 0 {
|
||||
// BW estimate available — state machine decides.
|
||||
return
|
||||
}
|
||||
target := maxActive
|
||||
if target > ls.maxLayer {
|
||||
target = ls.maxLayer
|
||||
}
|
||||
if target > ls.currentLayer {
|
||||
ls.switchLayer(target)
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentLayer returns the currently selected layer.
|
||||
func (ls *LayerSelector) CurrentLayer() int {
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
return ls.currentLayer
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) run() {
|
||||
defer close(ls.done)
|
||||
ticker := time.NewTicker(tickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ls.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
ls.tick()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) tick() {
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
|
||||
effectiveBW := ls.callbacks.GetEffectiveBW()
|
||||
if effectiveBW < 0 {
|
||||
// Stale estimate — do nothing.
|
||||
return
|
||||
}
|
||||
|
||||
switch ls.state {
|
||||
case stateStable:
|
||||
ls.tickStable(effectiveBW)
|
||||
case stateProbingUp:
|
||||
ls.tickProbingUp(effectiveBW)
|
||||
case stateGraceDown:
|
||||
ls.tickGraceDown(effectiveBW)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) tickStable(effectiveBW float64) {
|
||||
// Check for upswitch opportunity.
|
||||
nextLayer := ls.currentLayer + 1
|
||||
if nextLayer <= ls.maxLayer && nextLayer <= 2 {
|
||||
if !ls.inCooldown() && effectiveBW > layerBitrates[nextLayer].UpThresh {
|
||||
ls.state = stateProbingUp
|
||||
ls.probeTarget = nextLayer
|
||||
ls.probeStartTime = time.Now()
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: STABLE->PROBING_UP (BW=%.0fkbps, target=layer%d@%.0fkbps)",
|
||||
ls.receiverID, ls.senderID, effectiveBW/1000, nextLayer, layerBitrates[nextLayer].UpThresh/1000)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for downswitch need.
|
||||
if ls.currentLayer > 0 {
|
||||
if effectiveBW < layerBitrates[ls.currentLayer].DownThresh {
|
||||
ls.state = stateGraceDown
|
||||
ls.graceStartTime = time.Now()
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: STABLE->GRACE_DOWN (BW=%.0fkbps, thresh=%.0fkbps)",
|
||||
ls.receiverID, ls.senderID, effectiveBW/1000, layerBitrates[ls.currentLayer].DownThresh/1000)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) tickProbingUp(effectiveBW float64) {
|
||||
elapsed := time.Since(ls.probeStartTime)
|
||||
|
||||
// Abort if bandwidth dropped below current layer's nominal bitrate.
|
||||
if effectiveBW < layerBitrates[ls.currentLayer].Nominal {
|
||||
ls.state = stateStable
|
||||
ls.lastSwitchTime = time.Now() // enter cooldown
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (abort, BW=%.0fkbps < nominal=%.0fkbps)",
|
||||
ls.receiverID, ls.senderID, effectiveBW/1000, layerBitrates[ls.currentLayer].Nominal/1000)
|
||||
return
|
||||
}
|
||||
|
||||
// Probe complete — switch up.
|
||||
if elapsed >= probeDuration {
|
||||
if effectiveBW > layerBitrates[ls.probeTarget].Nominal {
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (success, switching to layer %d)",
|
||||
ls.receiverID, ls.senderID, ls.probeTarget)
|
||||
ls.switchLayer(ls.probeTarget)
|
||||
return
|
||||
}
|
||||
// BW not sufficient at end of probe — abort.
|
||||
ls.state = stateStable
|
||||
ls.lastSwitchTime = time.Now()
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (probe done but BW=%.0fkbps insufficient)",
|
||||
ls.receiverID, ls.senderID, effectiveBW/1000)
|
||||
return
|
||||
}
|
||||
|
||||
// Send RTX padding during probe.
|
||||
ls.sendProbePadding(elapsed)
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) tickGraceDown(effectiveBW float64) {
|
||||
// If bandwidth recovered, cancel grace period.
|
||||
if effectiveBW >= layerBitrates[ls.currentLayer].DownThresh {
|
||||
ls.state = stateStable
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: GRACE_DOWN->STABLE (recovered, BW=%.0fkbps)",
|
||||
ls.receiverID, ls.senderID, effectiveBW/1000)
|
||||
return
|
||||
}
|
||||
|
||||
// Grace period expired — downswitch.
|
||||
if time.Since(ls.graceStartTime) >= graceDownTimeout {
|
||||
targetLayer := ls.currentLayer - 1
|
||||
if targetLayer < 0 {
|
||||
targetLayer = 0
|
||||
}
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: GRACE_DOWN->STABLE (downswitch to layer %d)",
|
||||
ls.receiverID, ls.senderID, targetLayer)
|
||||
ls.switchLayer(targetLayer)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) switchLayer(newLayer int) {
|
||||
oldLayer := ls.currentLayer
|
||||
ls.currentLayer = newLayer
|
||||
ls.state = stateStable
|
||||
ls.lastSwitchTime = time.Now()
|
||||
ls.callbacks.SetSelectedLayer(newLayer)
|
||||
|
||||
// Request keyframe at the new layer.
|
||||
layers := ls.callbacks.GetSenderVideoLayers()
|
||||
if newLayer < len(layers) {
|
||||
ls.callbacks.SendPLI(layers[newLayer].SSRC)
|
||||
ls.callbacks.Log("INFO", "Participant %d<-%d: switched layer %d->%d (PLI sent for SSRC=%d)",
|
||||
ls.receiverID, ls.senderID, oldLayer, newLayer, layers[newLayer].SSRC)
|
||||
}
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) inCooldown() bool {
|
||||
return !ls.lastSwitchTime.IsZero() && time.Since(ls.lastSwitchTime) < cooldownDuration
|
||||
}
|
||||
|
||||
func (ls *LayerSelector) sendProbePadding(elapsed time.Duration) {
|
||||
// Calculate target padding rate: ramp from 0 to gap over probeDuration.
|
||||
gap := layerBitrates[ls.probeTarget].Nominal - layerBitrates[ls.currentLayer].Nominal
|
||||
progress := float64(elapsed) / float64(probeDuration)
|
||||
targetBps := gap * progress
|
||||
|
||||
// How many bytes to send in this 100ms tick.
|
||||
bytesPerTick := targetBps / 8 / (float64(time.Second) / float64(tickInterval))
|
||||
|
||||
rtxBuf := ls.callbacks.GetRtxBuffer()
|
||||
if rtxBuf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Pull packets from the ring buffer to fill the target bytes.
|
||||
entries := rtxBuf.Get(20) // enough for one tick
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
layers := ls.callbacks.GetSenderVideoLayers()
|
||||
if ls.currentLayer >= len(layers) {
|
||||
return
|
||||
}
|
||||
rtxSSRC := layers[ls.currentLayer].FidSSRC
|
||||
if rtxSSRC == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var sentBytes float64
|
||||
entryIdx := 0
|
||||
for sentBytes < bytesPerTick && entryIdx < len(entries) {
|
||||
entry := entries[entryIdx]
|
||||
entryIdx++
|
||||
rtxPayload := rtxEncapsulate(entry.Payload, entry.SeqNum)
|
||||
ls.probeRtxSeq++
|
||||
ls.callbacks.SendRtxPadding(rtxPayload, rtxSSRC, ls.probeRtxSeq, entry.Timestamp)
|
||||
sentBytes += float64(len(rtxPayload))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockCallbacks is a test harness for driving the LayerSelector. BW is
|
||||
// atomic so the selector's run() goroutine can read while the test writes.
|
||||
type mockCallbacks struct {
|
||||
bw atomic.Int64 // current effective bandwidth in bps; negative = stale
|
||||
layers []SimulcastLayer
|
||||
selectedLayer atomic.Int32
|
||||
pliCount atomic.Int32
|
||||
probePaddings atomic.Int32
|
||||
|
||||
logMu sync.Mutex
|
||||
logBuf []string
|
||||
}
|
||||
|
||||
func newMockCallbacks(layers []SimulcastLayer, initialBW float64) *mockCallbacks {
|
||||
m := &mockCallbacks{layers: layers}
|
||||
m.bw.Store(int64(initialBW))
|
||||
m.selectedLayer.Store(-1)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockCallbacks) setBW(bps float64) { m.bw.Store(int64(bps)) }
|
||||
|
||||
func (m *mockCallbacks) currentSelected() int { return int(m.selectedLayer.Load()) }
|
||||
|
||||
func (m *mockCallbacks) toCallbacks() LayerSelectorCallbacks {
|
||||
return LayerSelectorCallbacks{
|
||||
GetEffectiveBW: func() float64 {
|
||||
v := float64(m.bw.Load())
|
||||
if v < 0 {
|
||||
return -1
|
||||
}
|
||||
return v
|
||||
},
|
||||
SetSelectedLayer: func(layer int) {
|
||||
m.selectedLayer.Store(int32(layer))
|
||||
},
|
||||
SendPLI: func(ssrc uint32) {
|
||||
m.pliCount.Add(1)
|
||||
},
|
||||
GetSenderVideoLayers: func() []SimulcastLayer {
|
||||
return m.layers
|
||||
},
|
||||
GetRtxBuffer: func() *RtxRingBuffer {
|
||||
return nil // probing padding no-ops without a buffer
|
||||
},
|
||||
SendRtxPadding: func(rtxPayload []byte, rtxSSRC uint32, seqNum uint16, timestamp uint32) {
|
||||
m.probePaddings.Add(1)
|
||||
},
|
||||
Log: func(level string, format string, args ...interface{}) {
|
||||
m.logMu.Lock()
|
||||
m.logBuf = append(m.logBuf, fmt.Sprintf("["+level+"] "+format, args...))
|
||||
m.logMu.Unlock()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// waitForLayer polls the selector's currentLayer up to timeout for a change
|
||||
// to `want`. Returns true if reached, false on timeout.
|
||||
func waitForLayer(ls *LayerSelector, want int, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if ls.CurrentLayer() == want {
|
||||
return true
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testLayers() []SimulcastLayer {
|
||||
return []SimulcastLayer{
|
||||
{SSRC: 1001, FidSSRC: 1002},
|
||||
{SSRC: 1003, FidSSRC: 1004},
|
||||
{SSRC: 1005, FidSSRC: 1006},
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorUpswitch verifies L0 -> L1 -> L2 based on rising BW.
|
||||
//
|
||||
// Thresholds (from layerBitrates):
|
||||
//
|
||||
// L1 UpThresh = 132 kbps → needs REMB > ~155 kbps (with 0.85 safety factor)
|
||||
// L2 UpThresh = 1080 kbps → needs REMB > ~1271 kbps
|
||||
//
|
||||
// The selector's state machine enforces a 5s cooldown after each switch, so
|
||||
// the whole test runs in ~8-10 seconds.
|
||||
func TestLayerSelectorUpswitch(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), 200_000) // > L1 UpThresh
|
||||
ls := NewLayerSelector(1, 0, 0, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
// L0 -> L1: should enter PROBING_UP within 150ms (one tick), then
|
||||
// complete the 2s probe and switch to L1.
|
||||
if !waitForLayer(ls, 1, 3*time.Second) {
|
||||
t.Fatalf("L0->L1 upswitch timed out; currentLayer=%d selected=%d", ls.CurrentLayer(), m.currentSelected())
|
||||
}
|
||||
if got := m.currentSelected(); got != 1 {
|
||||
t.Fatalf("after L1 upswitch, SetSelectedLayer was not called with 1 (got %d)", got)
|
||||
}
|
||||
if pli := m.pliCount.Load(); pli < 1 {
|
||||
t.Fatalf("expected at least 1 PLI on layer switch, got %d", pli)
|
||||
}
|
||||
|
||||
// L1 -> L2: raise BW above L2 UpThresh. Wait out the 5s cooldown and
|
||||
// then the 2s probe (total ~7-8s).
|
||||
m.setBW(1_500_000)
|
||||
if !waitForLayer(ls, 2, 10*time.Second) {
|
||||
t.Fatalf("L1->L2 upswitch timed out; currentLayer=%d", ls.CurrentLayer())
|
||||
}
|
||||
if got := m.currentSelected(); got != 2 {
|
||||
t.Fatalf("after L2 upswitch, SetSelectedLayer was not called with 2 (got %d)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorDownswitch verifies L2 -> L1 -> L0 based on falling BW.
|
||||
// Starts the selector pre-positioned at L2 by setting its state directly
|
||||
// via `switchLayer`-equivalent initial-layer argument, then drives BW down.
|
||||
//
|
||||
// Thresholds:
|
||||
//
|
||||
// L2 DownThresh = 630 kbps → needs REMB < ~741 kbps
|
||||
// L1 DownThresh = 77 kbps → needs REMB < ~91 kbps
|
||||
//
|
||||
// Downswitches are governed by a 500ms grace period, no cooldown, so this
|
||||
// test runs in ~1.5 seconds.
|
||||
func TestLayerSelectorDownswitch(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), 1_500_000) // high BW, at L2
|
||||
ls := NewLayerSelector(1, 0, 2, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
// Drop BW below L2 downswitch threshold. Effective = 500k * 0.85 = 425k
|
||||
// is NOT below 630k effective threshold directly. Use 600k raw so
|
||||
// effective = 510k, well below 630k.
|
||||
m.setBW(600_000)
|
||||
if !waitForLayer(ls, 1, 2*time.Second) {
|
||||
t.Fatalf("L2->L1 downswitch timed out; currentLayer=%d", ls.CurrentLayer())
|
||||
}
|
||||
if got := m.currentSelected(); got != 1 {
|
||||
t.Fatalf("after L1 downswitch, SetSelectedLayer was not called with 1 (got %d)", got)
|
||||
}
|
||||
|
||||
// Drop below L1 downswitch threshold (77k effective → raw < 91k).
|
||||
// Use 50k raw → effective 42k.
|
||||
m.setBW(50_000)
|
||||
if !waitForLayer(ls, 0, 2*time.Second) {
|
||||
t.Fatalf("L1->L0 downswitch timed out; currentLayer=%d", ls.CurrentLayer())
|
||||
}
|
||||
if got := m.currentSelected(); got != 0 {
|
||||
t.Fatalf("after L0 downswitch, SetSelectedLayer was not called with 0 (got %d)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorGraceDownRecovery verifies that a transient BW dip that
|
||||
// recovers within the 500ms grace window does NOT cause a downswitch.
|
||||
func TestLayerSelectorGraceDownRecovery(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), 1_500_000)
|
||||
ls := NewLayerSelector(1, 0, 2, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
// Dip below downthresh, then recover before grace expires.
|
||||
m.setBW(500_000) // below L2 downthresh
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
m.setBW(1_500_000) // recovered
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if got := ls.CurrentLayer(); got != 2 {
|
||||
t.Fatalf("transient dip should not have downswitched; currentLayer=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorStaleBW verifies that with no REMB data (BW=-1), the
|
||||
// state machine does not transition.
|
||||
func TestLayerSelectorStaleBW(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), -1) // stale
|
||||
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
if got := ls.CurrentLayer(); got != 1 {
|
||||
t.Fatalf("stale BW should not trigger a transition; currentLayer=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorOnMaxActiveLayerIncreasedWhenStale verifies the fallback
|
||||
// path: when BW is stale (clients don't send REMB), discovery of a higher
|
||||
// active layer from the sender causes an immediate upswitch.
|
||||
func TestLayerSelectorOnMaxActiveLayerIncreasedWhenStale(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), -1)
|
||||
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
// Nothing has happened yet.
|
||||
if got := ls.CurrentLayer(); got != 1 {
|
||||
t.Fatalf("unexpected initial layer %d", got)
|
||||
}
|
||||
|
||||
// Sender starts producing L2. With stale BW, we should upshift
|
||||
// immediately up to the receiver's requested maxLayer.
|
||||
ls.OnMaxActiveLayerIncreased(2)
|
||||
if got := ls.CurrentLayer(); got != 2 {
|
||||
t.Fatalf("expected upshift to L2 on maxActive increase with stale BW; got %d", got)
|
||||
}
|
||||
if got := m.currentSelected(); got != 2 {
|
||||
t.Fatalf("SetSelectedLayer should have been called with 2; got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorOnMaxActiveLayerIncreasedWhenFresh verifies that when BW
|
||||
// is fresh, OnMaxActiveLayerIncreased is a no-op — the state machine is in
|
||||
// charge of layer selection.
|
||||
func TestLayerSelectorOnMaxActiveLayerIncreasedWhenFresh(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), 200_000) // fresh, enough for L1 only
|
||||
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
ls.OnMaxActiveLayerIncreased(2)
|
||||
if got := ls.CurrentLayer(); got != 1 {
|
||||
t.Fatalf("fresh BW should leave state machine in charge; current=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerSelectorRespectsMaxLayer verifies that upswitches never exceed
|
||||
// the receiver's requested maxLayer.
|
||||
func TestLayerSelectorRespectsMaxLayer(t *testing.T) {
|
||||
m := newMockCallbacks(testLayers(), 2_000_000) // way more than needed for L2
|
||||
ls := NewLayerSelector(1, 0, 0, 1, m.toCallbacks())
|
||||
defer ls.Stop()
|
||||
|
||||
// Wait long enough for an L0->L1 upswitch (~2.2s). Then wait past the
|
||||
// cooldown (5s) plus another probe window (2s) to ensure the selector
|
||||
// does NOT attempt to probe beyond maxLayer=1.
|
||||
if !waitForLayer(ls, 1, 3*time.Second) {
|
||||
t.Fatalf("L0->L1 upswitch timed out")
|
||||
}
|
||||
time.Sleep(8 * time.Second)
|
||||
if got := ls.CurrentLayer(); got != 1 {
|
||||
t.Fatalf("selector upshifted beyond maxLayer=1; got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
module github.com/nicegram/AltTgCalls/tools/go_sfu
|
||||
|
||||
go 1.24.2
|
||||
|
||||
require (
|
||||
github.com/pion/datachannel v1.5.10
|
||||
github.com/pion/dtls/v3 v3.0.6
|
||||
github.com/pion/ice/v4 v4.0.7
|
||||
github.com/pion/logging v0.2.3
|
||||
github.com/pion/rtcp v1.2.15
|
||||
github.com/pion/sctp v1.8.37
|
||||
github.com/pion/srtp/v3 v3.0.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/rtp v1.8.17 // indirect
|
||||
github.com/pion/stun/v3 v3.0.0 // indirect
|
||||
github.com/pion/transport/v3 v3.0.7 // indirect
|
||||
github.com/pion/turn/v4 v4.0.0 // indirect
|
||||
github.com/wlynxg/anet v0.0.3 // indirect
|
||||
golang.org/x/crypto v0.32.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/sys v0.29.0 // indirect
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o=
|
||||
github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M=
|
||||
github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
|
||||
github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
|
||||
github.com/pion/ice/v4 v4.0.7 h1:mnwuT3n3RE/9va41/9QJqN5+Bhc0H/x/ZyiVlWMw35M=
|
||||
github.com/pion/ice/v4 v4.0.7/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
|
||||
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
|
||||
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
|
||||
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
|
||||
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
|
||||
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
|
||||
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
|
||||
github.com/pion/rtp v1.8.17 h1:CFhaPN8Ikt9Sk7B3pic0kfwVia2dUMEtPSL34Gvihjw=
|
||||
github.com/pion/rtp v1.8.17/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
|
||||
github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs=
|
||||
github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
|
||||
github.com/pion/srtp/v3 v3.0.5 h1:8XLB6Dt3QXkMkRFpoqC3314BemkpMQK2mZeJc4pUKqo=
|
||||
github.com/pion/srtp/v3 v3.0.5/go.mod h1:r1G7y5r1scZRLe2QJI/is+/O83W2d+JoEsuIexpw+uM=
|
||||
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
|
||||
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
|
||||
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
|
||||
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
|
||||
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
|
||||
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
|
||||
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,239 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
muxReadBufSize = 8192
|
||||
muxChanBufSize = 256
|
||||
)
|
||||
|
||||
// isDTLS returns true if the first byte indicates a DTLS record (RFC 7983: 20–63).
|
||||
func isDTLS(b byte) bool {
|
||||
return b >= 20 && b <= 63
|
||||
}
|
||||
|
||||
// isRTPOrRTCP returns true if the first byte indicates an RTP/RTCP packet (RFC 7983: 128–191).
|
||||
func isRTPOrRTCP(b byte) bool {
|
||||
return b >= 128 && b <= 191
|
||||
}
|
||||
|
||||
// isRTCP returns true if the packet is RTCP (not RTP) per RFC 5761 Section 4.
|
||||
// RTCP packet types (byte[1]) are 200-211. RTP with Marker=1 and dynamic PT >= 96
|
||||
// gives byte[1] >= 224, so we use byte[1] >= 200 && byte[1] < 224 to exclude RTP.
|
||||
// In SRTCP the fixed header is unencrypted, so byte[1] is readable.
|
||||
func isRTCP(pkt []byte) bool {
|
||||
return len(pkt) >= 2 && pkt[1] >= 200 && pkt[1] < 224
|
||||
}
|
||||
|
||||
// PacketDemux reads from a net.Conn and routes packets to separate DTLS,
|
||||
// SRTP (RTP only), and RTCP channels based on RFC 7983 first-byte classification
|
||||
// and RTP/RTCP payload type demux.
|
||||
type PacketDemux struct {
|
||||
conn net.Conn
|
||||
dtlsCh chan []byte
|
||||
srtpCh chan []byte
|
||||
rtcpCh chan []byte
|
||||
once sync.Once
|
||||
closed chan struct{}
|
||||
label string
|
||||
}
|
||||
|
||||
func (d *PacketDemux) logf(format string, args ...interface{}) {
|
||||
fmt.Printf("[demux-%s] %s\n", d.label, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
// NewPacketDemux creates a PacketDemux and starts the read loop goroutine.
|
||||
func NewPacketDemux(conn net.Conn, label string) *PacketDemux {
|
||||
d := &PacketDemux{
|
||||
conn: conn,
|
||||
dtlsCh: make(chan []byte, muxChanBufSize),
|
||||
srtpCh: make(chan []byte, muxChanBufSize),
|
||||
rtcpCh: make(chan []byte, muxChanBufSize),
|
||||
closed: make(chan struct{}),
|
||||
label: label,
|
||||
}
|
||||
go d.readLoop()
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *PacketDemux) readLoop() {
|
||||
buf := make([]byte, muxReadBufSize)
|
||||
dtlsCount := 0
|
||||
srtpCount := 0
|
||||
rtcpCount := 0
|
||||
otherCount := 0
|
||||
for {
|
||||
n, err := d.conn.Read(buf)
|
||||
if err != nil {
|
||||
d.Close()
|
||||
return
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
pkt := make([]byte, n)
|
||||
copy(pkt, buf[:n])
|
||||
|
||||
switch {
|
||||
case isDTLS(pkt[0]):
|
||||
dtlsCount++
|
||||
if dtlsCount <= 5 {
|
||||
d.logf("DTLS packet #%d: %d bytes (first byte: 0x%02x)", dtlsCount, n, pkt[0])
|
||||
}
|
||||
select {
|
||||
case d.dtlsCh <- pkt:
|
||||
default:
|
||||
d.logf("DTLS channel full, dropping packet")
|
||||
}
|
||||
case isRTPOrRTCP(pkt[0]):
|
||||
if isRTCP(pkt) {
|
||||
rtcpCount++
|
||||
if rtcpCount <= 3 {
|
||||
d.logf("RTCP packet #%d: %d bytes (type byte: 0x%02x)", rtcpCount, n, pkt[1])
|
||||
}
|
||||
select {
|
||||
case d.rtcpCh <- pkt:
|
||||
default:
|
||||
// drop if channel full
|
||||
}
|
||||
} else {
|
||||
srtpCount++
|
||||
if srtpCount == 1 {
|
||||
d.logf("First SRTP packet: %d bytes", n)
|
||||
}
|
||||
select {
|
||||
case d.srtpCh <- pkt:
|
||||
default:
|
||||
// drop if channel full
|
||||
}
|
||||
}
|
||||
default:
|
||||
otherCount++
|
||||
if otherCount <= 3 {
|
||||
d.logf("Other packet: %d bytes (first byte: 0x%02x)", n, pkt[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the demuxer and the underlying connection.
|
||||
func (d *PacketDemux) Close() error {
|
||||
var err error
|
||||
d.once.Do(func() {
|
||||
close(d.closed)
|
||||
err = d.conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DTLSEndpoint returns a net.Conn that yields only DTLS packets.
|
||||
func (d *PacketDemux) DTLSEndpoint() net.Conn {
|
||||
return &demuxEndpoint{demux: d, ch: d.dtlsCh}
|
||||
}
|
||||
|
||||
// SRTPEndpoint returns a net.Conn that yields only SRTP (RTP) packets.
|
||||
// RTCP packets are routed to RTCPChannel() instead.
|
||||
func (d *PacketDemux) SRTPEndpoint() net.Conn {
|
||||
return &demuxEndpoint{demux: d, ch: d.srtpCh}
|
||||
}
|
||||
|
||||
// RTCPChannel returns a channel that receives raw encrypted SRTCP packets.
|
||||
// These must be decrypted externally (not via SessionSRTP which only handles RTP).
|
||||
func (d *PacketDemux) RTCPChannel() <-chan []byte {
|
||||
return d.rtcpCh
|
||||
}
|
||||
|
||||
// demuxEndpoint implements net.Conn for a single demux channel.
|
||||
type demuxEndpoint struct {
|
||||
demux *PacketDemux
|
||||
ch chan []byte
|
||||
mu sync.Mutex
|
||||
leftover []byte
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) Read(b []byte) (int, error) {
|
||||
e.mu.Lock()
|
||||
if len(e.leftover) > 0 {
|
||||
n := copy(b, e.leftover)
|
||||
e.leftover = e.leftover[n:]
|
||||
if len(e.leftover) == 0 {
|
||||
e.leftover = nil
|
||||
}
|
||||
e.mu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-e.demux.closed:
|
||||
return 0, io.EOF
|
||||
case pkt, ok := <-e.ch:
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(b, pkt)
|
||||
if n < len(pkt) {
|
||||
e.mu.Lock()
|
||||
e.leftover = pkt[n:]
|
||||
e.mu.Unlock()
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) Write(b []byte) (int, error) {
|
||||
return e.demux.conn.Write(b)
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) Close() error {
|
||||
return e.demux.Close()
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) LocalAddr() net.Addr {
|
||||
return e.demux.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) RemoteAddr() net.Addr {
|
||||
return e.demux.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) SetDeadline(t time.Time) error {
|
||||
return e.demux.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) SetReadDeadline(t time.Time) error {
|
||||
return e.demux.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (e *demuxEndpoint) SetWriteDeadline(t time.Time) error {
|
||||
return e.demux.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// connToPacketConn wraps a net.Conn into a net.PacketConn.
|
||||
// It is used to adapt a demuxEndpoint for pion/dtls.Server(), which
|
||||
// requires net.PacketConn. Since the endpoint is already bound to a
|
||||
// single peer, ReadFrom returns the conn's RemoteAddr and WriteTo ignores
|
||||
// the addr parameter.
|
||||
type connToPacketConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// WrapAsPacketConn adapts a net.Conn to net.PacketConn.
|
||||
func WrapAsPacketConn(c net.Conn) net.PacketConn {
|
||||
return &connToPacketConn{Conn: c}
|
||||
}
|
||||
|
||||
func (c *connToPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, err := c.Conn.Read(b)
|
||||
return n, c.Conn.RemoteAddr(), err
|
||||
}
|
||||
|
||||
func (c *connToPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NetworkSimulator models a uni-directional network pipe with delay, jitter,
|
||||
// packet loss, and bandwidth cap (token bucket).
|
||||
type NetworkSimulator struct {
|
||||
mu sync.Mutex
|
||||
delayMs int
|
||||
jitterMs int
|
||||
dropRate float64
|
||||
bandwidthBps int64
|
||||
|
||||
// Token bucket for bandwidth cap.
|
||||
tokens float64 // available tokens (bits)
|
||||
maxTokens float64 // max tokens = 200ms worth of bandwidth
|
||||
lastRefill time.Time
|
||||
rng *rand.Rand
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewNetworkSimulator creates a simulator with no simulation (passthrough).
|
||||
func NewNetworkSimulator() *NetworkSimulator {
|
||||
return &NetworkSimulator{
|
||||
lastRefill: time.Now(),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// SetParams reconfigures the simulator at runtime. Thread-safe.
|
||||
func (ns *NetworkSimulator) SetParams(delayMs, jitterMs int, dropRate float64, bandwidthBps int64) {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
ns.delayMs = delayMs
|
||||
ns.jitterMs = jitterMs
|
||||
ns.dropRate = dropRate
|
||||
ns.bandwidthBps = bandwidthBps
|
||||
if bandwidthBps > 0 {
|
||||
ns.maxTokens = float64(bandwidthBps) * 0.2 // 200ms buffer
|
||||
if ns.tokens > ns.maxTokens {
|
||||
ns.tokens = ns.maxTokens
|
||||
}
|
||||
} else {
|
||||
ns.maxTokens = 0
|
||||
ns.tokens = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Send processes a packet through the simulator. deliverFn is called
|
||||
// (possibly asynchronously) after simulation. The packet bytes are copied
|
||||
// if delivery is deferred.
|
||||
func (ns *NetworkSimulator) Send(pkt []byte, deliverFn func([]byte)) {
|
||||
ns.mu.Lock()
|
||||
if ns.closed {
|
||||
ns.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Drop check.
|
||||
if ns.dropRate > 0 && ns.rng.Float64() < ns.dropRate {
|
||||
ns.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Bandwidth cap: token bucket.
|
||||
if ns.bandwidthBps > 0 {
|
||||
ns.refillTokens()
|
||||
cost := float64(len(pkt)) * 8
|
||||
if ns.tokens < cost {
|
||||
// Queue full / no tokens — tail drop.
|
||||
ns.mu.Unlock()
|
||||
return
|
||||
}
|
||||
ns.tokens -= cost
|
||||
}
|
||||
|
||||
// Calculate delay.
|
||||
delayMs := ns.delayMs
|
||||
if ns.jitterMs > 0 {
|
||||
delayMs += ns.rng.Intn(2*ns.jitterMs+1) - ns.jitterMs
|
||||
if delayMs < 0 {
|
||||
delayMs = 0
|
||||
}
|
||||
}
|
||||
ns.mu.Unlock()
|
||||
|
||||
if delayMs == 0 {
|
||||
deliverFn(pkt)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy packet for deferred delivery.
|
||||
pktCopy := make([]byte, len(pkt))
|
||||
copy(pktCopy, pkt)
|
||||
time.AfterFunc(time.Duration(delayMs)*time.Millisecond, func() {
|
||||
deliverFn(pktCopy)
|
||||
})
|
||||
}
|
||||
|
||||
// Close stops the simulator. Pending delayed packets may still fire.
|
||||
func (ns *NetworkSimulator) Close() {
|
||||
ns.mu.Lock()
|
||||
ns.closed = true
|
||||
ns.mu.Unlock()
|
||||
}
|
||||
|
||||
// refillTokens adds tokens based on elapsed time. Must be called with mu held.
|
||||
func (ns *NetworkSimulator) refillTokens() {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(ns.lastRefill).Seconds()
|
||||
ns.lastRefill = now
|
||||
ns.tokens += float64(ns.bandwidthBps) * elapsed
|
||||
if ns.tokens > ns.maxTokens {
|
||||
ns.tokens = ns.maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
// IsPassthrough returns true if no simulation is configured.
|
||||
func (ns *NetworkSimulator) IsPassthrough() bool {
|
||||
ns.mu.Lock()
|
||||
defer ns.mu.Unlock()
|
||||
return ns.delayMs == 0 && ns.jitterMs == 0 && ns.dropRate == 0 && ns.bandwidthBps == 0
|
||||
}
|
||||
@@ -0,0 +1,637 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/datachannel"
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/pion/ice/v4"
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/sctp"
|
||||
"github.com/pion/srtp/v3"
|
||||
)
|
||||
|
||||
// ParticipantConfig holds the client's join parameters extracted from the join payload.
|
||||
type ParticipantConfig struct {
|
||||
AudioSSRC uint32
|
||||
Ufrag string
|
||||
Pwd string
|
||||
Fingerprint string // SHA-256, colon-separated uppercase hex (e.g., "AB:CD:EF:...")
|
||||
}
|
||||
|
||||
// Participant holds the per-participant transport stack: ICE → DTLS → SRTP + SCTP/DataChannel.
|
||||
type Participant struct {
|
||||
ID int
|
||||
AudioSSRC uint32
|
||||
|
||||
iceAgent *ice.Agent
|
||||
iceConn *ice.Conn
|
||||
|
||||
demux *PacketDemux
|
||||
dtlsConn *dtls.Conn
|
||||
|
||||
srtpSession *srtp.SessionSRTP
|
||||
srtpWriter *srtp.WriteStreamSRTP
|
||||
srtpProfile srtp.ProtectionProfile
|
||||
srtpKeys srtp.SessionKeys // saved for creating SRTCP contexts
|
||||
|
||||
// Separate SRTCP contexts for manual RTCP decrypt/encrypt.
|
||||
// These are independent from the SessionSRTP used for RTP.
|
||||
srtcpRemoteCtx *srtp.Context // decrypt SRTCP received from this participant
|
||||
srtcpLocalCtx *srtp.Context // encrypt SRTCP sent to this participant
|
||||
srtcpMu sync.Mutex // protects srtcpLocalCtx (single-writer)
|
||||
|
||||
sctpAssoc *sctp.Association
|
||||
dataChannel *datachannel.DataChannel
|
||||
|
||||
tlsCert tls.Certificate
|
||||
fingerprint string // SHA-256, colon-separated uppercase hex
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
loggerFactory logging.LoggerFactory
|
||||
log logging.LeveledLogger
|
||||
|
||||
// Video layer selection: receiver requests which layer to receive from each sender.
|
||||
videoLayerMu sync.RWMutex
|
||||
requestedLayers map[int]int // senderID -> layer index
|
||||
|
||||
// Bandwidth estimation from REMB.
|
||||
bwEstimator *BandwidthEstimator
|
||||
|
||||
// Selected layers: what the SFU actually forwards (set by LayerSelector).
|
||||
selectedLayerMu sync.RWMutex
|
||||
selectedLayers map[int]int // senderID -> layer index
|
||||
onColibriMessage func(participantID int, msg string) // set before Connect(), read from acceptDataChannel goroutine
|
||||
|
||||
// RTCP feedback callback: called when PLI or FIR is received from this participant.
|
||||
// mediaSSRC is the SSRC the receiver wants a keyframe for.
|
||||
onRTCPFeedback func(participantID int, mediaSSRC uint32, isFIR bool)
|
||||
|
||||
// Network simulation (delay/jitter/loss/bandwidth cap per direction).
|
||||
ingressSim *NetworkSimulator
|
||||
egressSim *NetworkSimulator
|
||||
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewParticipant creates a new Participant with an ICE agent and self-signed certificate.
|
||||
// It does NOT start ICE gathering or connection — call GatherCandidates() and Connect() for that.
|
||||
func NewParticipant(id int, config ParticipantConfig, loggerFactory logging.LoggerFactory) (*Participant, error) {
|
||||
log := loggerFactory.NewLogger(fmt.Sprintf("participant-%d", id))
|
||||
|
||||
// Generate self-signed ECDSA P-256 certificate.
|
||||
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ECDSA key: %w", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{certDER},
|
||||
PrivateKey: privKey,
|
||||
}
|
||||
|
||||
// Compute SHA-256 fingerprint of the DER certificate.
|
||||
hash := sha256.Sum256(certDER)
|
||||
fingerprint := formatFingerprint(hash[:])
|
||||
|
||||
// Create ICE agent — UDP, host candidates, ICE-lite.
|
||||
// The tgcalls GroupNetworkManager hardcodes ICEROLE_CONTROLLED for the client,
|
||||
// so the SFU must be the controlling side (use Dial, not Accept).
|
||||
// ICE-lite: the SFU passively accepts incoming connectivity checks.
|
||||
// No remote candidates needed: when the client's STUN binding requests arrive,
|
||||
// pion creates peer-reflexive candidates automatically.
|
||||
agent, err := ice.NewAgent(&ice.AgentConfig{
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4},
|
||||
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost},
|
||||
Lite: true,
|
||||
IncludeLoopback: true,
|
||||
IPFilter: func(ip net.IP) bool {
|
||||
return true // accept all interfaces, including loopback
|
||||
},
|
||||
LoggerFactory: loggerFactory,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ICE agent: %w", err)
|
||||
}
|
||||
|
||||
localUfrag, localPwd, err := agent.GetLocalUserCredentials()
|
||||
if err != nil {
|
||||
_ = agent.Close()
|
||||
return nil, fmt.Errorf("get local credentials: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Created participant %d (SSRC=%d, ufrag=%s, fingerprint=%s)", id, config.AudioSSRC, localUfrag, fingerprint)
|
||||
|
||||
return &Participant{
|
||||
ID: id,
|
||||
AudioSSRC: config.AudioSSRC,
|
||||
iceAgent: agent,
|
||||
tlsCert: tlsCert,
|
||||
fingerprint: fingerprint,
|
||||
localUfrag: localUfrag,
|
||||
localPwd: localPwd,
|
||||
loggerFactory: loggerFactory,
|
||||
log: log,
|
||||
requestedLayers: make(map[int]int),
|
||||
bwEstimator: &BandwidthEstimator{},
|
||||
selectedLayers: make(map[int]int),
|
||||
ingressSim: NewNetworkSimulator(),
|
||||
egressSim: NewNetworkSimulator(),
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Fingerprint returns the SHA-256 fingerprint of the participant's DTLS certificate.
|
||||
func (p *Participant) Fingerprint() string {
|
||||
return p.fingerprint
|
||||
}
|
||||
|
||||
// LocalUfrag returns the local ICE username fragment.
|
||||
func (p *Participant) LocalUfrag() string {
|
||||
return p.localUfrag
|
||||
}
|
||||
|
||||
// LocalPwd returns the local ICE password.
|
||||
func (p *Participant) LocalPwd() string {
|
||||
return p.localPwd
|
||||
}
|
||||
|
||||
// GatherCandidates triggers ICE gathering and waits for completion.
|
||||
// Returns the gathered ICE candidates.
|
||||
func (p *Participant) GatherCandidates() ([]ice.Candidate, error) {
|
||||
var (
|
||||
candidates []ice.Candidate
|
||||
mu sync.Mutex
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
if err := p.iceAgent.OnCandidate(func(c ice.Candidate) {
|
||||
if c == nil {
|
||||
// nil candidate signals gathering complete.
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
candidates = append(candidates, c)
|
||||
mu.Unlock()
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("set OnCandidate: %w", err)
|
||||
}
|
||||
|
||||
if err := p.iceAgent.GatherCandidates(); err != nil {
|
||||
return nil, fmt.Errorf("gather candidates: %w", err)
|
||||
}
|
||||
|
||||
<-done
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
p.log.Infof("Gathered %d ICE candidates", len(candidates))
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
// Connect establishes the full transport stack: ICE → DTLS → SRTP + SCTP.
|
||||
// The SFU is DTLS client (active). tgcalls GroupNetworkManager hardcodes SSL_SERVER.
|
||||
//
|
||||
// iceControlling selects the ICE role:
|
||||
// - true (Dial): SFU is controlling. Required for tgcalls GroupNetworkManager which
|
||||
// hardcodes ICEROLE_CONTROLLED (non-standard).
|
||||
// - false (Accept): SFU is controlled (standard for ICE-lite). Required for PeerConnection
|
||||
// clients that follow RFC 8445 (full agent = controlling when remote is ice-lite).
|
||||
func (p *Participant) Connect(ctx context.Context, remoteUfrag, remotePwd string, iceControlling bool) error {
|
||||
// 1. ICE connection.
|
||||
var iceConn *ice.Conn
|
||||
var err error
|
||||
if iceControlling {
|
||||
iceConn, err = p.iceAgent.Dial(ctx, remoteUfrag, remotePwd)
|
||||
} else {
|
||||
iceConn, err = p.iceAgent.Accept(ctx, remoteUfrag, remotePwd)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("ICE dial: %w", err)
|
||||
}
|
||||
p.iceConn = iceConn
|
||||
p.log.Infof("ICE connected")
|
||||
|
||||
// 2. Demux: split DTLS and SRTP traffic.
|
||||
p.demux = NewPacketDemux(iceConn, fmt.Sprintf("p%d", p.ID))
|
||||
|
||||
// 3. DTLS: client-side handshake over the DTLS endpoint.
|
||||
// tgcalls GroupNetworkManager hardcodes SetDtlsRole(SSL_SERVER), so the SFU must be the DTLS client.
|
||||
dtlsEndpoint := p.demux.DTLSEndpoint()
|
||||
remoteAddr := dtlsEndpoint.RemoteAddr()
|
||||
packetConn := WrapAsPacketConn(dtlsEndpoint)
|
||||
|
||||
dtlsConn, err := dtls.Client(packetConn, remoteAddr, &dtls.Config{
|
||||
Certificates: []tls.Certificate{p.tlsCert},
|
||||
// Offer GCM profiles matching tgcalls GroupNetworkManager::getDefaulCryptoOptions()
|
||||
// which enables enable_gcm_crypto_suites=true and disables AES-128-CM-SHA1-80.
|
||||
SRTPProtectionProfiles: []dtls.SRTPProtectionProfile{
|
||||
dtls.SRTP_AEAD_AES_256_GCM,
|
||||
dtls.SRTP_AEAD_AES_128_GCM,
|
||||
},
|
||||
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||
InsecureSkipVerify: true, // tgcalls verifies fingerprint out-of-band; we skip TLS chain verification
|
||||
LoggerFactory: p.loggerFactory,
|
||||
})
|
||||
if err != nil {
|
||||
p.demux.Close()
|
||||
return fmt.Errorf("DTLS create: %w", err)
|
||||
}
|
||||
p.dtlsConn = dtlsConn
|
||||
|
||||
// dtls.Client() is lazy; explicitly run the handshake before accessing ConnectionState.
|
||||
if err := dtlsConn.HandshakeContext(ctx); err != nil {
|
||||
p.demux.Close()
|
||||
return fmt.Errorf("DTLS handshake: %w", err)
|
||||
}
|
||||
p.log.Infof("DTLS connected")
|
||||
|
||||
// 4. Extract SRTP keying material from DTLS.
|
||||
state, ok := dtlsConn.ConnectionState()
|
||||
if !ok {
|
||||
return fmt.Errorf("DTLS connection state not available")
|
||||
}
|
||||
|
||||
// Map the negotiated DTLS-SRTP protection profile to a pion/srtp ProtectionProfile.
|
||||
negotiatedProfile, profileOk := dtlsConn.SelectedSRTPProtectionProfile()
|
||||
if !profileOk {
|
||||
p.demux.Close()
|
||||
return fmt.Errorf("no SRTP protection profile negotiated")
|
||||
}
|
||||
var srtpProfile srtp.ProtectionProfile
|
||||
switch negotiatedProfile {
|
||||
case dtls.SRTP_AEAD_AES_256_GCM:
|
||||
srtpProfile = srtp.ProtectionProfileAeadAes256Gcm
|
||||
case dtls.SRTP_AEAD_AES_128_GCM:
|
||||
srtpProfile = srtp.ProtectionProfileAeadAes128Gcm
|
||||
case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
|
||||
srtpProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
|
||||
case dtls.SRTP_AES128_CM_HMAC_SHA1_32:
|
||||
srtpProfile = srtp.ProtectionProfileAes128CmHmacSha1_32
|
||||
default:
|
||||
p.demux.Close()
|
||||
return fmt.Errorf("unsupported SRTP protection profile: 0x%04x", negotiatedProfile)
|
||||
}
|
||||
p.log.Infof("Negotiated SRTP profile: 0x%04x", negotiatedProfile)
|
||||
|
||||
srtpConfig := &srtp.Config{
|
||||
Profile: srtpProfile,
|
||||
}
|
||||
// SFU is DTLS client → isClient=true
|
||||
if err := srtpConfig.ExtractSessionKeysFromDTLS(&state, true); err != nil {
|
||||
return fmt.Errorf("extract SRTP keys: %w", err)
|
||||
}
|
||||
|
||||
// Save keys and profile for creating SRTCP contexts.
|
||||
p.srtpProfile = srtpProfile
|
||||
p.srtpKeys = srtpConfig.Keys
|
||||
|
||||
// 5. SRTP session over the SRTP endpoint (RTP only — RTCP is handled separately).
|
||||
srtpEndpoint := p.demux.SRTPEndpoint()
|
||||
srtpSession, err := srtp.NewSessionSRTP(srtpEndpoint, srtpConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SRTP session: %w", err)
|
||||
}
|
||||
p.srtpSession = srtpSession
|
||||
|
||||
srtpWriter, err := srtpSession.OpenWriteStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open SRTP write stream: %w", err)
|
||||
}
|
||||
p.srtpWriter = srtpWriter
|
||||
p.log.Infof("SRTP session established")
|
||||
|
||||
// 5b. Create separate SRTCP contexts for manual RTCP handling.
|
||||
// Remote context: decrypt SRTCP received from this participant (their local = our remote).
|
||||
p.srtcpRemoteCtx, err = srtp.CreateContext(
|
||||
p.srtpKeys.RemoteMasterKey, p.srtpKeys.RemoteMasterSalt, p.srtpProfile,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SRTCP remote context: %w", err)
|
||||
}
|
||||
// Local context: encrypt SRTCP we send to this participant (our local keys).
|
||||
p.srtcpLocalCtx, err = srtp.CreateContext(
|
||||
p.srtpKeys.LocalMasterKey, p.srtpKeys.LocalMasterSalt, p.srtpProfile,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SRTCP local context: %w", err)
|
||||
}
|
||||
p.log.Infof("SRTCP contexts created")
|
||||
|
||||
// 5c. Start RTCP read loop.
|
||||
go p.readRTCPLoop()
|
||||
|
||||
// 6. SCTP association over DTLS.
|
||||
sctpAssoc, err := sctp.Server(sctp.Config{
|
||||
NetConn: dtlsConn,
|
||||
LoggerFactory: p.loggerFactory,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SCTP association: %w", err)
|
||||
}
|
||||
p.sctpAssoc = sctpAssoc
|
||||
p.log.Infof("SCTP association established")
|
||||
|
||||
// 7. Start goroutine to accept data channels.
|
||||
go p.acceptDataChannel()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptDataChannel waits for the client to open a data channel and reads Colibri messages.
|
||||
func (p *Participant) acceptDataChannel() {
|
||||
dc, err := datachannel.Accept(p.sctpAssoc, &datachannel.Config{
|
||||
LoggerFactory: p.loggerFactory,
|
||||
})
|
||||
if err != nil {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return // Expected during shutdown.
|
||||
default:
|
||||
p.log.Warnf("Accept data channel: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
p.dataChannel = dc
|
||||
p.log.Infof("Data channel accepted")
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, isString, err := dc.ReadDataChannel(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
default:
|
||||
p.log.Debugf("Data channel read error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if isString {
|
||||
msg := string(buf[:n])
|
||||
p.log.Debugf("Colibri message: %s", msg)
|
||||
if p.onColibriMessage != nil {
|
||||
p.onColibriMessage(p.ID, msg)
|
||||
}
|
||||
} else {
|
||||
p.log.Debugf("Data channel binary message (%d bytes)", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetColibriCallback sets the callback for incoming Colibri data channel messages.
|
||||
func (p *Participant) SetColibriCallback(cb func(participantID int, msg string)) {
|
||||
p.onColibriMessage = cb
|
||||
}
|
||||
|
||||
// SetRTCPFeedbackCallback sets the callback for PLI/FIR RTCP feedback from this participant.
|
||||
func (p *Participant) SetRTCPFeedbackCallback(cb func(participantID int, mediaSSRC uint32, isFIR bool)) {
|
||||
p.onRTCPFeedback = cb
|
||||
}
|
||||
|
||||
// readRTCPLoop reads encrypted SRTCP packets from the demux RTCP channel,
|
||||
// decrypts them, parses for PLI/FIR, and invokes the feedback callback.
|
||||
func (p *Participant) readRTCPLoop() {
|
||||
rtcpCh := p.demux.RTCPChannel()
|
||||
decryptBuf := make([]byte, 8192)
|
||||
pktCount := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closed:
|
||||
return
|
||||
case encrypted, ok := <-rtcpCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt SRTCP.
|
||||
decrypted, err := p.srtcpRemoteCtx.DecryptRTCP(decryptBuf[:0], encrypted, nil)
|
||||
if err != nil {
|
||||
pktCount++
|
||||
if pktCount <= 5 {
|
||||
p.log.Debugf("SRTCP decrypt error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse RTCP compound packet.
|
||||
packets, err := rtcp.Unmarshal(decrypted)
|
||||
if err != nil {
|
||||
p.log.Debugf("RTCP unmarshal error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pkt := range packets {
|
||||
switch fb := pkt.(type) {
|
||||
case *rtcp.PictureLossIndication:
|
||||
p.log.Infof("Received PLI from participant %d for MediaSSRC=%d", p.ID, fb.MediaSSRC)
|
||||
if p.onRTCPFeedback != nil {
|
||||
p.onRTCPFeedback(p.ID, fb.MediaSSRC, false)
|
||||
}
|
||||
case *rtcp.FullIntraRequest:
|
||||
for _, entry := range fb.FIR {
|
||||
p.log.Infof("Received FIR from participant %d for SSRC=%d", p.ID, entry.SSRC)
|
||||
if p.onRTCPFeedback != nil {
|
||||
p.onRTCPFeedback(p.ID, entry.SSRC, true)
|
||||
}
|
||||
}
|
||||
case *rtcp.ReceiverEstimatedMaximumBitrate:
|
||||
bps := float64(fb.Bitrate)
|
||||
p.bwEstimator.OnREMB(bps)
|
||||
p.log.Debugf("REMB from participant %d: %.0f bps (smoothed=%.0f, effective=%.0f)",
|
||||
p.ID, bps, p.bwEstimator.SmoothedBps(), p.bwEstimator.EffectiveBps())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WriteRTCP sends a plaintext RTCP packet to this participant, encrypting it with
|
||||
// the local SRTCP context and writing directly to the ICE connection.
|
||||
func (p *Participant) WriteRTCP(data []byte) error {
|
||||
if p.srtcpLocalCtx == nil || p.iceConn == nil {
|
||||
return fmt.Errorf("SRTCP context or ICE conn not established")
|
||||
}
|
||||
if p.egressSim.IsPassthrough() {
|
||||
return p.writeRTCPDirect(data)
|
||||
}
|
||||
p.egressSim.Send(data, func(delayed []byte) {
|
||||
p.writeRTCPDirect(delayed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Participant) writeRTCPDirect(data []byte) error {
|
||||
p.srtcpMu.Lock()
|
||||
encrypted, err := p.srtcpLocalCtx.EncryptRTCP(nil, data, nil)
|
||||
p.srtcpMu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt SRTCP: %w", err)
|
||||
}
|
||||
_, err = p.iceConn.Write(encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetRequestedLayer sets the video layer this receiver wants from a given sender.
|
||||
func (p *Participant) SetRequestedLayer(senderID int, layer int) {
|
||||
p.videoLayerMu.Lock()
|
||||
p.requestedLayers[senderID] = layer
|
||||
p.videoLayerMu.Unlock()
|
||||
}
|
||||
|
||||
// GetRequestedLayer returns the video layer this receiver wants from a given sender.
|
||||
// Returns -1 if no layer is requested (meaning: don't forward video from this sender).
|
||||
func (p *Participant) GetRequestedLayer(senderID int) int {
|
||||
p.videoLayerMu.RLock()
|
||||
defer p.videoLayerMu.RUnlock()
|
||||
if layer, ok := p.requestedLayers[senderID]; ok {
|
||||
return layer
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// SetSelectedLayer sets the video layer the SFU actually forwards from a given sender to this receiver.
|
||||
func (p *Participant) SetSelectedLayer(senderID int, layer int) {
|
||||
p.selectedLayerMu.Lock()
|
||||
p.selectedLayers[senderID] = layer
|
||||
p.selectedLayerMu.Unlock()
|
||||
}
|
||||
|
||||
// GetSelectedLayer returns the video layer the SFU forwards from a given sender to this receiver.
|
||||
// Returns -1 if no layer is selected (don't forward).
|
||||
func (p *Participant) GetSelectedLayer(senderID int) int {
|
||||
p.selectedLayerMu.RLock()
|
||||
defer p.selectedLayerMu.RUnlock()
|
||||
if layer, ok := p.selectedLayers[senderID]; ok {
|
||||
return layer
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// SendText sends a UTF-8 string message over the data channel.
|
||||
// Returns an error if the data channel is not yet established.
|
||||
func (p *Participant) SendText(msg string) error {
|
||||
dc := p.dataChannel
|
||||
if dc == nil {
|
||||
return fmt.Errorf("data channel not established")
|
||||
}
|
||||
_, err := dc.WriteDataChannel([]byte(msg), true)
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteRTP sends an encrypted RTP packet to this participant via the SRTP write stream.
|
||||
func (p *Participant) WriteRTP(pkt []byte) (int, error) {
|
||||
if p.srtpWriter == nil {
|
||||
return 0, fmt.Errorf("SRTP session not established")
|
||||
}
|
||||
if p.egressSim.IsPassthrough() {
|
||||
return p.srtpWriter.Write(pkt)
|
||||
}
|
||||
var n int
|
||||
var writeErr error
|
||||
p.egressSim.Send(pkt, func(delayed []byte) {
|
||||
n, writeErr = p.srtpWriter.Write(delayed)
|
||||
})
|
||||
return n, writeErr
|
||||
}
|
||||
|
||||
// AcceptStream blocks until a new SRTP read stream appears (new SSRC from client).
|
||||
// Returns the read stream and its SSRC.
|
||||
func (p *Participant) AcceptStream() (*srtp.ReadStreamSRTP, uint32, error) {
|
||||
if p.srtpSession == nil {
|
||||
return nil, 0, fmt.Errorf("SRTP session not established")
|
||||
}
|
||||
return p.srtpSession.AcceptStream()
|
||||
}
|
||||
|
||||
// Close tears down all transport layers in order.
|
||||
func (p *Participant) Close() error {
|
||||
var firstErr error
|
||||
p.once.Do(func() {
|
||||
close(p.closed)
|
||||
|
||||
if p.ingressSim != nil {
|
||||
p.ingressSim.Close()
|
||||
}
|
||||
if p.egressSim != nil {
|
||||
p.egressSim.Close()
|
||||
}
|
||||
|
||||
if p.dataChannel != nil {
|
||||
if err := p.dataChannel.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.sctpAssoc != nil {
|
||||
if err := p.sctpAssoc.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.srtpSession != nil {
|
||||
if err := p.srtpSession.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.dtlsConn != nil {
|
||||
if err := p.dtlsConn.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.demux != nil {
|
||||
if err := p.demux.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.iceConn != nil {
|
||||
if err := p.iceConn.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if p.iceAgent != nil {
|
||||
if err := p.iceAgent.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
p.log.Infof("Participant %d closed", p.ID)
|
||||
})
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// formatFingerprint converts a hash byte slice to colon-separated uppercase hex.
|
||||
func formatFingerprint(hash []byte) string {
|
||||
result := make([]byte, 0, len(hash)*3-1)
|
||||
for i, b := range hash {
|
||||
if i > 0 {
|
||||
result = append(result, ':')
|
||||
}
|
||||
result = append(result, fmt.Sprintf("%02X", b)...)
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
+1240
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,262 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/rtcp"
|
||||
)
|
||||
|
||||
// --- RTP Header Extension Parsing ---
|
||||
|
||||
// parseTWCCSeq extracts the transport-wide sequence number from an RTP packet.
|
||||
// extID is the header extension ID to look for (typically 3).
|
||||
// Returns the sequence number and true if found, or 0 and false.
|
||||
func parseTWCCSeq(pkt []byte, extID int) (uint16, bool) {
|
||||
if len(pkt) < 12 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Check extension bit (X) in RTP header byte 0.
|
||||
if pkt[0]&0x10 == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Skip fixed header (12 bytes) + CSRC list.
|
||||
cc := int(pkt[0] & 0x0F)
|
||||
offset := 12 + cc*4
|
||||
if offset+4 > len(pkt) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Check for one-byte header extension (0xBEDE magic).
|
||||
if pkt[offset] != 0xBE || pkt[offset+1] != 0xDE {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Extension length in 32-bit words.
|
||||
extLen := int(binary.BigEndian.Uint16(pkt[offset+2:])) * 4
|
||||
offset += 4
|
||||
extEnd := offset + extLen
|
||||
if extEnd > len(pkt) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Scan extension elements: [id:4][len:4][data...].
|
||||
for offset < extEnd {
|
||||
b := pkt[offset]
|
||||
if b == 0 {
|
||||
// Padding byte.
|
||||
offset++
|
||||
continue
|
||||
}
|
||||
id := int(b >> 4)
|
||||
dataLen := int(b&0x0F) + 1 // len field is 0-based
|
||||
offset++
|
||||
if id == extID && dataLen >= 2 && offset+2 <= extEnd {
|
||||
seq := binary.BigEndian.Uint16(pkt[offset:])
|
||||
return seq, true
|
||||
}
|
||||
offset += dataLen
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// --- Transport-CC Feedback Generator ---
|
||||
|
||||
type twccArrival struct {
|
||||
seq uint16
|
||||
arrivalUs int64 // microseconds since generator creation
|
||||
}
|
||||
|
||||
// TransportCCGenerator generates RTCP transport-cc feedback for a sender.
|
||||
// It tracks packet arrivals and emits feedback every 100ms.
|
||||
type TransportCCGenerator struct {
|
||||
mu sync.Mutex
|
||||
arrivals []twccArrival
|
||||
startTime time.Time
|
||||
fbCount uint8 // feedback packet counter
|
||||
|
||||
// Callback to send the feedback RTCP packet.
|
||||
sendFeedback func(data []byte)
|
||||
|
||||
stopCh chan struct{}
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewTransportCCGenerator creates and starts a generator.
|
||||
// sendFeedback is called with marshalled+encrypted RTCP data to send to the sender.
|
||||
func NewTransportCCGenerator(sendFeedback func(data []byte)) *TransportCCGenerator {
|
||||
g := &TransportCCGenerator{
|
||||
startTime: time.Now(),
|
||||
sendFeedback: sendFeedback,
|
||||
stopCh: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go g.run()
|
||||
return g
|
||||
}
|
||||
|
||||
// RecordArrival records a packet arrival. Thread-safe.
|
||||
func (g *TransportCCGenerator) RecordArrival(twccSeq uint16) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
arrivalUs := time.Since(g.startTime).Microseconds()
|
||||
g.arrivals = append(g.arrivals, twccArrival{seq: twccSeq, arrivalUs: arrivalUs})
|
||||
}
|
||||
|
||||
// Stop terminates the generator.
|
||||
func (g *TransportCCGenerator) Stop() {
|
||||
close(g.stopCh)
|
||||
<-g.done
|
||||
}
|
||||
|
||||
func (g *TransportCCGenerator) run() {
|
||||
defer close(g.done)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-g.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
g.emitFeedback()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *TransportCCGenerator) emitFeedback() {
|
||||
g.mu.Lock()
|
||||
if len(g.arrivals) == 0 {
|
||||
g.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Take all arrivals.
|
||||
arrivals := g.arrivals
|
||||
g.arrivals = nil
|
||||
g.fbCount++
|
||||
fbCount := g.fbCount
|
||||
g.mu.Unlock()
|
||||
|
||||
// Sort by sequence number (should already be mostly sorted).
|
||||
for i := 1; i < len(arrivals); i++ {
|
||||
for j := i; j > 0 && seqBefore(arrivals[j].seq, arrivals[j-1].seq); j-- {
|
||||
arrivals[j], arrivals[j-1] = arrivals[j-1], arrivals[j]
|
||||
}
|
||||
}
|
||||
|
||||
baseSeq := arrivals[0].seq
|
||||
// Number of sequence numbers covered (including gaps).
|
||||
lastSeq := arrivals[len(arrivals)-1].seq
|
||||
packetCount := seqDiff(baseSeq, lastSeq) + 1
|
||||
|
||||
// Reference time: arrival of first packet in 64ms units.
|
||||
refTimeUs := arrivals[0].arrivalUs
|
||||
refTime := uint32(refTimeUs / 64000) // 64ms units, 24-bit in spec but stored as uint32
|
||||
|
||||
// Build received set for gap detection.
|
||||
receivedAt := make(map[uint16]int64, len(arrivals))
|
||||
for _, a := range arrivals {
|
||||
receivedAt[a.seq] = a.arrivalUs
|
||||
}
|
||||
|
||||
// Build packet chunks and recv deltas.
|
||||
var chunks []rtcp.PacketStatusChunk
|
||||
var deltas []*rtcp.RecvDelta
|
||||
|
||||
// Process in runs of up to 7 (status vector chunk capacity for 2-bit symbols).
|
||||
prevArrivalUs := refTimeUs
|
||||
var statusList []uint16
|
||||
|
||||
seq := baseSeq
|
||||
for i := 0; i < int(packetCount); i++ {
|
||||
arrUs, received := receivedAt[seq]
|
||||
if received {
|
||||
deltaUs := arrUs - prevArrivalUs
|
||||
if deltaUs >= 0 && deltaUs <= 63750 { // fits in small delta (0-255 * 250us)
|
||||
statusList = append(statusList, rtcp.TypeTCCPacketReceivedSmallDelta)
|
||||
deltas = append(deltas, &rtcp.RecvDelta{
|
||||
Type: rtcp.TypeTCCPacketReceivedSmallDelta,
|
||||
Delta: deltaUs,
|
||||
})
|
||||
} else {
|
||||
statusList = append(statusList, rtcp.TypeTCCPacketReceivedLargeDelta)
|
||||
deltas = append(deltas, &rtcp.RecvDelta{
|
||||
Type: rtcp.TypeTCCPacketReceivedLargeDelta,
|
||||
Delta: deltaUs,
|
||||
})
|
||||
}
|
||||
prevArrivalUs = arrUs
|
||||
} else {
|
||||
statusList = append(statusList, rtcp.TypeTCCPacketNotReceived)
|
||||
}
|
||||
seq++
|
||||
}
|
||||
|
||||
// Encode status list as status vector chunks (7 symbols per chunk with 2-bit symbols).
|
||||
for i := 0; i < len(statusList); i += 7 {
|
||||
end := i + 7
|
||||
if end > len(statusList) {
|
||||
end = len(statusList)
|
||||
}
|
||||
chunk := statusList[i:end]
|
||||
|
||||
// Check if all same status (use run-length).
|
||||
allSame := true
|
||||
for _, s := range chunk {
|
||||
if s != chunk[0] {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allSame && len(chunk) >= 2 {
|
||||
chunks = append(chunks, &rtcp.RunLengthChunk{
|
||||
Type: rtcp.TypeTCCRunLengthChunk,
|
||||
PacketStatusSymbol: chunk[0],
|
||||
RunLength: uint16(len(chunk)),
|
||||
})
|
||||
} else {
|
||||
// Status vector with 2-bit symbols.
|
||||
symbolList := make([]uint16, len(chunk))
|
||||
copy(symbolList, chunk)
|
||||
chunks = append(chunks, &rtcp.StatusVectorChunk{
|
||||
Type: rtcp.TypeTCCStatusVectorChunk,
|
||||
SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit,
|
||||
SymbolList: symbolList,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fb := &rtcp.TransportLayerCC{
|
||||
SenderSSRC: 1,
|
||||
MediaSSRC: 0,
|
||||
BaseSequenceNumber: baseSeq,
|
||||
PacketStatusCount: packetCount,
|
||||
ReferenceTime: refTime,
|
||||
FbPktCount: fbCount,
|
||||
PacketChunks: chunks,
|
||||
RecvDeltas: deltas,
|
||||
}
|
||||
|
||||
data, err := rtcp.Marshal([]rtcp.Packet{fb})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
g.sendFeedback(data)
|
||||
}
|
||||
|
||||
// seqBefore returns true if a comes before b in the uint16 sequence space.
|
||||
func seqBefore(a, b uint16) bool {
|
||||
return int16(a-b) < 0
|
||||
}
|
||||
|
||||
// seqDiff returns the forward distance from a to b in uint16 sequence space.
|
||||
func seqDiff(a, b uint16) uint16 {
|
||||
return b - a
|
||||
}
|
||||
Reference in New Issue
Block a user