feat: add testbench (CLI tool, Go SFU, Dockerfile, docs)

Brings the testbench source into the submodule:
- tools/cli/ — C++ CLI test tool (P2P, reflector, group, group-churn modes)
- tools/go_sfu/ — Go/Pion SFU library, c-archive linked into tgcalls_cli
- Dockerfile — multi-stage Linux container build
- CLAUDE.md (top-level), tools/cli/CLAUDE.md, tools/go_sfu/CLAUDE.md — docs

Bazel build glue (.bazelrc, MODULE.bazel, third-party BUILD edits, tgcalls_core
target) remains in the outer repo since the dependency stack lives there;
labels and paths in this repo reference the outer-repo workspace root.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
isaac
2026-04-30 19:00:55 +02:00
parent 1f119d8f32
commit aaa583f1a9
28 changed files with 5996 additions and 0 deletions
+202
View File
@@ -0,0 +1,202 @@
# CLAUDE.md
This is a testbench repository for the tgcalls VoIP library (from Telegram). It contains the full Telegram iOS source tree as a build dependency, but the focus is on testing and debugging tgcalls.
## Build
Requires Bazel 8.4.2 (download to `build-input/` if not present):
```bash
# One-time setup: create build configuration stub
mkdir -p build-input/configuration-repository/provisioning
# Then populate MODULE.bazel, BUILD, variables.bzl, provisioning/BUILD
# (see build-input/configuration-repository/ for existing stubs)
# Build the CLI test tool
./build-input/bazel-8.4.2 build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli
```
The system-installed Bazel (v9) is NOT compatible with this codebase.
## Linux Build
Prerequisites (Ubuntu/Debian):
```bash
apt install gcc g++ cmake meson ninja-build nasm make autoconf automake libtool pkg-config zlib1g-dev libbz2-dev
```
Download the Linux Bazel 8.4.2 binary to `build-input/`:
```bash
curl -fL "https://github.com/bazelbuild/bazel/releases/download/8.4.2/bazel-8.4.2-linux-arm64" -o build-input/bazel-8.4.2-linux
chmod +x build-input/bazel-8.4.2-linux
```
Build the CLI test tool:
```bash
./build-input/bazel-8.4.2-linux build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli
```
The same Bazel 8.4.2 version is required. The build uses the system GCC toolchain and system-installed cmake/meson/ninja for third-party library compilation.
## Docker Build
Build a minimal Linux container image from macOS (or any Docker host):
```bash
# Build (uses BuildKit cache — first build ~5 min, rebuilds seconds)
docker build -t tgcalls-test .
# Run locally
docker run --rm tgcalls-test --mode p2p --duration 5 --quiet
docker run --rm tgcalls-test --mode reflector --reflector 91.108.13.2:598 --duration 10 --quiet
# Push to ECR for AWS deployment
docker tag tgcalls-test 654654616143.dkr.ecr.eu-west-1.amazonaws.com/tgcalls-test:latest
docker push 654654616143.dkr.ecr.eu-west-1.amazonaws.com/tgcalls-test:latest
```
The Dockerfile uses a multi-stage build: full build environment in stage 1, minimal runtime image (~50MB) in stage 2. Bazel's build cache is preserved across `docker build` invocations via `--mount=type=cache`. The image is built for ARM64 (matches Apple Silicon and Fargate ARM).
## Testing
### Local Mass Testing
Run large-scale P2P tests locally using `run-local-test.sh`. Launches N parallel processes, each running a single call, and aggregates results.
```bash
# 1000 calls, 150 parallel, 30% loss (default settings)
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 1000
# Custom parallelism and duration
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 500 -j 100 -d 30
# Custom loss parameters
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-local-test.sh -n 1000 --drop-rate 0.5 --delay 100-300
```
Options: `-n NUM` (count), `-j PARALLEL` (default 150), `-d DURATION` (default 15s), `--drop-rate RATE` (default 0.3), `--delay MIN-MAX` (default 50-200), `--mode MODE` (default p2p), `--version VER` (default 13.0.0).
Typical results: 100% success rate at 30% loss on Apple Silicon (16 cores).
### AWS Mass Testing
Run large-scale reflector tests on ECS Fargate (ARM64). Infrastructure is pre-configured in eu-west-1. Requires Docker push first.
```bash
# Launch 1000 tasks across all Telegram reflectors, 30s each
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-test.sh -n 1000 -d 30
# Collect results
./submodules/TgVoipWebrtc/tgcalls/tools/cli/run-test.sh --results
```
The script fetches the reflector list from `https://core.telegram.org/getReflectorList`, embeds the IPs as a `--reflector-list` argument (each task picks a random IP + random port 596-599), and launches in waves of 500 (Fargate concurrent task limit). Results are collected from CloudWatch Logs with automatic retry for delayed log delivery.
**AWS resources** (eu-west-1, account 654654616143):
- ECR: `tgcalls-test`
- ECS cluster: `tgcalls-test`
- Task definition: `tgcalls-test` (ARM64 Fargate, 0.25 vCPU, 512MB)
- CloudWatch log group: `/ecs/tgcalls-test`
- Subnets: `subnet-0292f49f3b4885428`, `subnet-09b8edab6eb20b837`, `subnet-0f464b5c62c9a6d1a`
- Security group: `sg-0d87a1f19be76c160`
**Cost**: ~$0.01 per 100 tasks (~$0.10 per 1000-task run).
## tgcalls CLI Test Tool
Located at `submodules/TgVoipWebrtc/tgcalls/tools/cli/`. Runs tgcalls instances in-process with emulated signaling and validates audio/media flow.
```bash
# P2P mode (direct loopback, no network)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 10
# Reflector mode (routes through a real Telegram UDP reflector)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode reflector --reflector 91.108.13.2:596 --duration 10
# Random reflector from a list (picks one at random, randomizes port 596-599 if no port given)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --reflector-list "91.108.13.2,91.108.13.3,91.108.9.1" --duration 10
# Simulate lossy signaling (30% drop, 50-200ms random delay)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 30 --drop-rate 0.3 --delay 50-200
# Quiet mode (summary only, full tgcalls logs dumped on failure)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode p2p --duration 5 --quiet
# Group mode (in-process SFU with N participants)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 3 --duration 10
# Mixed group mode (CustomImpl + ReferenceImpl participants)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --reference-participants 2 --duration 15
# Group mode with video (H264 simulcast, pattern generator)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 15
# Mixed group with video (both CustomImpl and ReferenceImpl send/receive video)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --reference-participants 2 --video --duration 15
# ReferenceImpl-only video
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 0 --reference-participants 3 --video --duration 15
# Group churn stress test (100 join/leave cycles, then validate base group)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 3 --duration 10
# Group churn with video
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 3 --video --churn-cycles 100 --duration 10
# Mixed implementations churn
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group-churn --participants 2 --reference-participants 1 --video --duration 10
```
`--mode` is required (`p2p`, `reflector`, `group`, or `group-churn`) unless `--reflector-list` is used (implies reflector mode). Exit code 0 = success. Exit code 1 = failure.
For p2p/reflector: success = call established, stats logs non-empty, BWE non-zero for both sides.
For group (audio): success = all N participants report `isConnected = true` AND all participants receive remote audio (non-zero SSRC with level > 0.05 via `audioLevelsUpdated`). Remote 440Hz sine tone typically arrives at ~0.126 level.
For group (video): audio criteria plus every participant receives ≥1 decoded video frame from every other participant via `FakeVideoSink` frame counting.
For group-churn: success = all churn cycles complete without crash/hang AND base group passes group validation (all connected, all receiving audio, and if video, all receiving video from all other base participants).
### CLI Options
- `--mode p2p|reflector|group|group-churn` — call mode (required unless `--reflector-list` used)
- `--reflector host:port` — single reflector address
- `--reflector-list addr,addr,...` — comma-separated list, one picked at random
- `--version VER` — caller tgcalls protocol version (default: `13.0.0`)
- `--version2 VER` — callee tgcalls protocol version (default: same as `--version`). Enables cross-version interop testing.
- `--participants N` — number of CustomImpl participants in group mode (default: 3)
- `--reference-participants N` — number of ReferenceImpl (PeerConnection-based) participants in group mode (default: 0). Total = `--participants` + `--reference-participants`.
- `--duration N` — test duration in seconds (default: 10)
- `--drop-rate 0.0-1.0` — signaling packet drop probability
- `--delay min-max` — signaling delay range in ms (e.g., `50-200`)
- `--video` — enable H264 video with simulcast in group mode (both CustomImpl and ReferenceImpl participants)
- `--churn-cycles N` — number of join/leave cycles in group-churn mode (default: 100)
- `--network-scenario NAME` — network simulation test scenario (e.g., `step-down-up`). Group mode only.
- `--quiet` — summary output only
### Modes
- **P2P**: Direct loopback, `enableP2P=true`, no servers configured
- **Reflector**: Routes through a Telegram UDP reflector, `enableP2P=false`, configures `RtcServer` with `login="reflector"` and random peer tags (16 bytes, byte 0 = `0x00` for caller, `0x01` for callee)
- **Group**: In-process SFU with N participants using `GroupInstanceCustomImpl` and/or `GroupInstanceReferenceImpl`. The SFU is implemented in Go using Pion's low-level ICE/DTLS/SRTP/SCTP APIs (not PeerConnection), linked into the same process via CGo c-archive. Each participant gets a full ICE + DTLS + SRTP + SCTP transport stack over localhost UDP. Audio RTP is selectively forwarded between all participants. With `--video`, H264 video with 3-layer simulcast is enabled. Mixed-implementation groups (CustomImpl + ReferenceImpl) are supported via `--reference-participants`.
- **Group Churn**: Stress test for participant join/leave dynamics. Creates a base group of N participants, then rapidly cycles an additional participant in and out `--churn-cycles` times (default 100). After churn, validates that the base group is healthy: all connected, all receiving audio, and if `--video` is enabled, all receiving video. Alternates between CustomImpl and ReferenceImpl for the cycling participant. The `--duration` controls the stabilization wait after churn completes.
## Project Structure
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/` — CLI test tool (main.cpp, group_mode.cpp, group_participant.h/.cpp, group_churn_mode.h/.cpp, fake_video_source.h/.cpp, fake_video_sink.h, run-test.sh, run-local-test.sh, BUILD)
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/` — Go/Pion SFU library (sfu.go, participant.go, mux.go, go.mod/go.sum), built as c-archive via rules_go + Gazelle, linked into tgcalls_cli
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/` — tgcalls library source
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/group/` — group call implementations (GroupInstanceCustomImpl, GroupInstanceReferenceImpl, GroupNetworkManager, GroupJoinPayloadInternal)
- `submodules/TgVoipWebrtc/tgcalls/tgcalls/v2/` — v2 implementation (InstanceV2Impl, InstanceV2ReferenceImpl, InstanceV2CompatImpl, NativeNetworkingImpl, SignalingSctpConnection, SignalingTranslator)
- `submodules/TgVoipWebrtc/BUILD` — contains `tgcalls_core` target (C++ only, macOS-native) and `TgVoipWebrtc` target (iOS, ObjC)
- `third-party/webrtc/` — WebRTC source and BUILD
- `third-party/webrtc/webrtc/net/dcsctp/` — dc-sctp (SCTP implementation)
- `third-party/webrtc/webrtc/media/sctp/dcsctp_transport.cc` — WebRTC SCTP wrapper
- `third-party/` — other dependencies (opus, libvpx, ffmpeg, boringssl, etc.)
## Code Style
- **Naming**: PascalCase for types, camelCase for variables/methods
- **Language**: C++17 for tgcalls code
- **Formatting**: Standard C++ formatting
## Further Context
When working in these areas, additional `CLAUDE.md` files load automatically:
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/CLAUDE.md` — CLI test tool architecture (P2P/Reflector, Group), supported version matrix
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/CLAUDE.md` — Go SFU internals: build integration, bandwidth adaptation, transport-cc feedback, network simulation
- `submodules/TgVoipWebrtc/CLAUDE.md` — tgcalls library internals: macOS/Linux build patches, SCTP signaling, InstanceV2CompatImpl, GroupInstanceCustomImpl/ReferenceImpl, video pitfalls, known issues
+52
View File
@@ -0,0 +1,52 @@
# syntax=docker/dockerfile:1
# Multi-stage build for tgcalls_cli Linux container
# Build: docker build -t tgcalls-test .
# Run: docker run tgcalls-test --mode reflector --reflector 91.108.13.2:598 --duration 10
# ============================================================
# Stage 1: Build
# ============================================================
FROM ubuntu:24.04 AS builder
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc g++ cmake meson ninja-build nasm make \
autoconf automake libtool pkg-config python3 \
unzip curl ca-certificates patch \
zlib1g-dev libbz2-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /src
# Copy source tree
COPY . .
# Always download Bazel for the container's architecture (host copy may be wrong arch)
RUN ARCH=$(uname -m) && \
if [ "$ARCH" = "x86_64" ]; then BAZEL_ARCH="x86_64"; \
elif [ "$ARCH" = "aarch64" ]; then BAZEL_ARCH="arm64"; \
else echo "Unsupported arch: $ARCH" && exit 1; fi && \
curl -fL "https://github.com/bazelbuild/bazel/releases/download/8.4.2/bazel-8.4.2-linux-${BAZEL_ARCH}" \
-o build-input/bazel-8.4.2-linux && \
chmod +x build-input/bazel-8.4.2-linux
# Build with persistent Bazel cache
RUN --mount=type=cache,target=/root/.cache/bazel \
./build-input/bazel-8.4.2-linux build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli \
--strategy=Genrule=standalone --spawn_strategy=standalone && \
cp bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli /tmp/tgcalls_cli
# ============================================================
# Stage 2: Runtime (minimal)
# ============================================================
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /tmp/tgcalls_cli /usr/local/bin/tgcalls_cli
ENTRYPOINT ["tgcalls_cli"]
CMD ["--help"]
+37
View File
@@ -0,0 +1,37 @@
cc_binary(
name = "tgcalls_cli",
srcs = [
"main.cpp",
"group_mode.cpp",
"group_mode.h",
"group_participant.cpp",
"group_participant.h",
"group_churn_mode.cpp",
"group_churn_mode.h",
"fake_video_source.h",
"fake_video_source.cpp",
"fake_video_sink.h",
],
copts = [
"-I{}/tgcalls/tgcalls".format("submodules/TgVoipWebrtc"),
"-Ithird-party/webrtc/webrtc",
"-Ithird-party/webrtc/dependencies",
"-Ithird-party/webrtc/absl",
"-Ithird-party/libyuv",
"-DRTC_ENABLE_VP9",
"-DNDEBUG",
"-std=c++17",
"-w",
] + select({
"@platforms//os:linux": ["-DWEBRTC_LINUX", "-DWEBRTC_POSIX"],
"//conditions:default": ["-DWEBRTC_MAC", "-DWEBRTC_POSIX"],
}),
linkopts = select({
"@platforms//os:linux": ["-lpthread", "-lm", "-ldl"],
"//conditions:default": ["-framework", "CoreFoundation", "-framework", "Security"],
}),
deps = [
"//submodules/TgVoipWebrtc:tgcalls_core",
"//submodules/TgVoipWebrtc/tgcalls/tools/go_sfu",
],
)
+41
View File
@@ -0,0 +1,41 @@
# tgcalls CLI Test Tool
In-process test harness for tgcalls. See the root `CLAUDE.md` for build instructions, top-level CLI usage, and the CLI options reference.
## Supported Versions
| Version | Implementation | Notes |
|---|---|---|
| `14.0.0` | `InstanceV2CompatImpl` | WebRTC PeerConnection + V2Impl signaling. Cross-version interop with 7.0.013.0.0 |
| `13.0.0` (default) | `InstanceV2Impl` | Also: 7.0.0, 8.0.0, 9.0.0, 12.0.0 |
| `11.0.0` | `InstanceV2ReferenceImpl` | Also: 10.0.0. Uses WebRTC PeerConnection |
| `5.0.0` | `InstanceImpl` (v1) | Also: 2.7.7. Legacy |
## Architecture (P2P/Reflector)
- Two `tgcalls::Instance` objects (caller + callee) created via `Meta::Create(version, ...)`
- Signaling bridged via `SignalingBridge` with configurable drop rate and delay
- `FakeAudioDeviceModule` with `SineRecorder` (440Hz tone) and `NoOpRenderer` (audio discarded; validation via BWE)
- `FakeInterface` platform implementation (pure C++, no iOS/ObjC deps)
- Stats log validation: both caller and callee write `config.statsLogPath` with bitrate records; non-empty log with at least one non-zero BWE value is a success condition
- On failure, full tgcalls internal logs (caller + callee) are dumped to stdout via `config.logPath`
## Architecture (Group)
- N participants using `GroupInstanceCustomImpl` and/or `GroupInstanceReferenceImpl` connect to an in-process Go SFU
- SFU uses Pion's low-level APIs (pion/ice, pion/dtls, pion/srtp, pion/sctp) — NOT PeerConnection
- ICE: lite mode, loopback-only, UDP host candidates on 127.0.0.1. SFU uses `Dial` (controlling) for CustomImpl clients and `Accept` (controlled) for PeerConnection clients
- DTLS: SFU acts as DTLS client (setup=active); GroupNetworkManager hardcodes SSL_SERVER for the tgcalls client
- SRTP: AES-256-GCM (negotiated via DTLS-SRTP; GroupNetworkManager requires GCM suites)
- SCTP: over DTLS, accepts data channel from client, reads Colibri messages, sends `ActiveAudioSsrcs` and `ActiveVideoSsrcs` notifications
- RTP forwarding: audio RTP forwarded to all others unconditionally; video RTP forwarded only to receivers that have requested video from that sender (via `ReceiverVideoConstraints`)
- SSRC tracking: SFU maintains `ssrcRegistry map[uint32]ssrcInfo` with kind (audio/video/video-rtx) and simulcast layer index, exposed via `GoSfu_QuerySsrc` and `GoSfu_QueryVideoSsrcs` CGo exports
- SSRC discovery: SFU broadcasts `ActiveAudioSsrcs` and `ActiveVideoSsrcs` over data channel when participants connect
- Video SSRC groups: parsed from join payload `"ssrc-groups"` field (SIM + FID semantics), stored per participant
- Colibri video constraints: SFU parses `ReceiverVideoConstraints` from receivers, sends `SenderVideoConstraints` back to senders with `idealHeight`, and sends proactive PLI to trigger keyframes when a receiver first requests video
- RTCP feedback: SFU demuxes SRTCP from the shared ICE transport (RFC 5761: byte[1] >= 200 && < 224), decrypts with per-participant SRTCP contexts, parses PLI/FIR, and forwards as new PLI to the sender. NACK is terminated (not forwarded).
- Audio validation: `audioLevelsUpdated` callback tracks remote audio levels; success requires every participant to receive audio from at least one other participant (remote SSRC != 0, level > 0.05). The 440Hz sine tone arrives at ~0.126 level after SFU forwarding.
- Video validation: `FakeVideoSink` (implements `rtc::VideoSinkInterface<VideoFrame>`) counts decoded frames per remote endpoint; success requires every participant to receive ≥1 frame from every other
- Video signaling flow: SFU broadcasts `ActiveVideoSsrcs` over data channel → `dataChannelMessageReceived` callback fires in the app → app calls `setRequestedVideoChannels` → CustomImpl creates `IncomingVideoChannel` / ReferenceImpl adds recvonly video transceiver → both send `ReceiverVideoConstraints` → SFU sends `SenderVideoConstraints` + proactive PLI → sender produces keyframe → receiver decodes
- `dataChannelMessageReceived` callback: added to `GroupInstanceDescriptor`, forwards all incoming Colibri data channel messages to the application. Used by the CLI test tool to react to `ActiveVideoSsrcs` and dynamically set up video channels — mirrors the real Telegram app's reactive flow
- `FakeAudioDeviceModule` with `SineRecorder` (440Hz tone) and `NoOpRenderer` — same as P2P mode
- `FakeVideoTrackSource` generates 1280x720 I420 frames at 30fps with per-participant color tint and frame counter (720p needed for 3 simulcast layers; 640x360 only allows 2 per WebRTC's `kSimulcastFormats`)
- Group mode source: `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.cpp`
+24
View File
@@ -0,0 +1,24 @@
#pragma once
#include "api/video/video_frame.h"
#include "api/video/video_sink_interface.h"
#include <atomic>
class FakeVideoSink : public rtc::VideoSinkInterface<webrtc::VideoFrame> {
public:
void OnFrame(const webrtc::VideoFrame& frame) override {
frameCount_.fetch_add(1, std::memory_order_relaxed);
int w = frame.width();
int h = frame.height();
lastWidth_.store(w, std::memory_order_relaxed);
lastHeight_.store(h, std::memory_order_relaxed);
}
int lastWidth() const { return lastWidth_.load(std::memory_order_relaxed); }
int lastHeight() const { return lastHeight_.load(std::memory_order_relaxed); }
int frameCount() const {
return frameCount_.load(std::memory_order_relaxed);
}
private:
std::atomic<int> frameCount_{0};
std::atomic<int> lastWidth_{0};
std::atomic<int> lastHeight_{0};
};
+149
View File
@@ -0,0 +1,149 @@
#include "fake_video_source.h"
#include "api/video/video_frame.h"
#include "api/video/video_rotation.h"
#include "rtc_base/time_utils.h"
#include <cstring>
namespace {
constexpr int kWidth = 1280;
constexpr int kHeight = 720;
constexpr int kFps = 30;
constexpr uint8_t kBgY = 80; // dark background
constexpr uint8_t kDigitY = 235; // white digits
constexpr int kDigitW = 5;
constexpr int kDigitH = 7;
constexpr int kScale = 4;
constexpr int kDigitSpacing = 2; // pixels between digits (scaled)
constexpr int kMargin = 8; // top-left margin in pixels
// 5x7 bitmap font for digits 0-9. Each entry is 7 rows of 5-bit patterns.
// MSB = leftmost pixel.
static const uint8_t kDigitBitmaps[10][7] = {
// 0
{0b01110, 0b10001, 0b10011, 0b10101, 0b11001, 0b10001, 0b01110},
// 1
{0b00100, 0b01100, 0b00100, 0b00100, 0b00100, 0b00100, 0b01110},
// 2
{0b01110, 0b10001, 0b00001, 0b00010, 0b00100, 0b01000, 0b11111},
// 3
{0b11111, 0b00010, 0b00100, 0b00010, 0b00001, 0b10001, 0b01110},
// 4
{0b00010, 0b00110, 0b01010, 0b10010, 0b11111, 0b00010, 0b00010},
// 5
{0b11111, 0b10000, 0b11110, 0b00001, 0b00001, 0b10001, 0b01110},
// 6
{0b00110, 0b01000, 0b10000, 0b11110, 0b10001, 0b10001, 0b01110},
// 7
{0b11111, 0b00001, 0b00010, 0b00100, 0b01000, 0b01000, 0b01000},
// 8
{0b01110, 0b10001, 0b10001, 0b01110, 0b10001, 0b10001, 0b01110},
// 9
{0b01110, 0b10001, 0b10001, 0b01111, 0b00001, 0b00010, 0b01100},
};
// 6 color tints cycling: red, green, blue, yellow, cyan, magenta
// UV values for each tint (in I420, U=Cb, V=Cr; neutral=128)
struct UVTint { uint8_t u; uint8_t v; };
static const UVTint kTints[6] = {
{90, 240}, // red
{54, 34}, // green
{240, 110}, // blue
{16, 146}, // yellow
{166, 16}, // cyan
{166, 240}, // magenta
};
} // namespace
FakeVideoTrackSource::FakeVideoTrackSource(int participantId)
: participantId_(participantId) {
const auto& tint = kTints[participantId % 6];
uTint_ = tint.u;
vTint_ = tint.v;
thread_ = std::thread(&FakeVideoTrackSource::GenerateThread, this);
}
FakeVideoTrackSource::~FakeVideoTrackSource() {
Stop();
}
rtc::scoped_refptr<FakeVideoTrackSource> FakeVideoTrackSource::Create(int participantId) {
return rtc::scoped_refptr<FakeVideoTrackSource>(
new rtc::RefCountedObject<FakeVideoTrackSource>(participantId));
}
void FakeVideoTrackSource::Stop() {
if (running_.exchange(false)) {
if (thread_.joinable()) {
thread_.join();
}
}
}
void FakeVideoTrackSource::GenerateThread() {
int frameNumber = 0;
while (running_.load(std::memory_order_relaxed)) {
auto buffer = webrtc::I420Buffer::Create(kWidth, kHeight);
// Fill Y plane with dark background
memset(buffer->MutableDataY(), kBgY, buffer->StrideY() * kHeight);
// Fill U plane with tint
int uvHeight = (kHeight + 1) / 2;
memset(buffer->MutableDataU(), uTint_, buffer->StrideU() * uvHeight);
// Fill V plane with tint
memset(buffer->MutableDataV(), vTint_, buffer->StrideV() * uvHeight);
// Render frame counter digits
RenderDigits(buffer->MutableDataY(), buffer->StrideY(), frameNumber);
auto frame = webrtc::VideoFrame::Builder()
.set_video_frame_buffer(buffer)
.set_rotation(webrtc::kVideoRotation_0)
.set_timestamp_us(rtc::TimeMicros())
.build();
OnFrame(frame);
++frameNumber;
std::this_thread::sleep_for(std::chrono::milliseconds(1000 / kFps));
}
}
void FakeVideoTrackSource::RenderDigits(uint8_t* yPlane, int strideY, int frameNumber) {
// Convert frame number to decimal digits
char numStr[16];
snprintf(numStr, sizeof(numStr), "%d", frameNumber);
int numDigits = static_cast<int>(strlen(numStr));
int xOffset = kMargin;
for (int d = 0; d < numDigits; ++d) {
int digit = numStr[d] - '0';
const uint8_t* bitmap = kDigitBitmaps[digit];
for (int row = 0; row < kDigitH; ++row) {
uint8_t rowBits = bitmap[row];
for (int col = 0; col < kDigitW; ++col) {
if (rowBits & (1 << (kDigitW - 1 - col))) {
// Fill scaled pixel block
int px = xOffset + col * kScale;
int py = kMargin + row * kScale;
for (int sy = 0; sy < kScale; ++sy) {
for (int sx = 0; sx < kScale; ++sx) {
int x = px + sx;
int y = py + sy;
if (x < kWidth && y < kHeight) {
yPlane[y * strideY + x] = kDigitY;
}
}
}
}
}
}
xOffset += kDigitW * kScale + kDigitSpacing;
}
}
+38
View File
@@ -0,0 +1,38 @@
#pragma once
#include "api/video/i420_buffer.h"
#include "media/base/adapted_video_track_source.h"
#include "rtc_base/ref_counted_object.h"
#include <atomic>
#include <thread>
// Generates 640x360 I420 frames at 30fps with per-participant color tint
// and an incrementing frame counter rendered as block digits.
class FakeVideoTrackSource : public rtc::AdaptedVideoTrackSource {
public:
static rtc::scoped_refptr<FakeVideoTrackSource> Create(int participantId);
~FakeVideoTrackSource() override;
void Stop();
// VideoTrackSourceInterface
SourceState state() const override { return kLive; }
bool remote() const override { return false; }
bool is_screencast() const override { return false; }
absl::optional<bool> needs_denoising() const override { return false; }
protected:
explicit FakeVideoTrackSource(int participantId);
private:
void GenerateThread();
void RenderDigits(uint8_t* yPlane, int strideY, int frameNumber);
int participantId_;
uint8_t uTint_;
uint8_t vTint_;
std::atomic<bool> running_{true};
std::thread thread_;
};
+205
View File
@@ -0,0 +1,205 @@
#include "group_churn_mode.h"
#include "group_participant.h"
#include <chrono>
#include <cstdio>
#include <thread>
#include <unistd.h>
// CGo header
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
int runGroupChurnMode(
int customParticipants,
int referenceParticipants,
int duration,
bool quiet,
bool video,
int churnCycles
) {
gGroupQuiet = quiet;
gGroupStartTime = std::chrono::steady_clock::now();
int baseCount = customParticipants + referenceParticipants;
if (baseCount < 2) {
fprintf(stderr, "Error: need at least 2 base participants total\n");
return 1;
}
groupLog("Churn", "initializing Go SFU...");
int rc = GoSfu_Init();
if (rc != 0) {
fprintf(stderr, "Error: GoSfu_Init failed with %d\n", rc);
return 1;
}
GoInt sfuHandle = GoSfu_Create();
if (sfuHandle <= 0) {
fprintf(stderr, "Error: GoSfu_Create failed\n");
return 1;
}
groupLog("Churn", "SFU handle=%lld, base=%d (custom=%d, ref=%d), cycles=%d, video=%s",
(long long)sfuHandle, baseCount, customParticipants, referenceParticipants,
churnCycles, video ? "yes" : "no");
auto threads = tgcalls::StaticThreads::getThreads();
// --- Phase 1: Create base group ---
groupLog("Churn", "creating base group...");
std::vector<std::unique_ptr<ParticipantState>> baseStates;
bool anyFailed = false;
for (int i = 0; i < baseCount; ++i) {
bool isReference = (i >= customParticipants);
auto state = createParticipant(i, isReference, sfuHandle, threads, quiet, video, &baseStates);
if (!state) {
anyFailed = true;
continue;
}
baseStates.push_back(std::move(state));
}
// Wait for all base participants to connect
groupLog("Churn", "waiting for base group connections...");
auto waitStart = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(15)) {
int connectedCount = 0;
for (const auto& s : baseStates) {
if (s->wasConnected.load()) connectedCount++;
}
if (connectedCount == (int)baseStates.size()) {
groupLog("Churn", "all %d base participants connected", (int)baseStates.size());
break;
}
groupLog("Churn", "base connected: %d/%d", connectedCount, (int)baseStates.size());
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
// Wait for audio to flow in base group
groupLog("Churn", "waiting for base group audio...");
waitStart = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(10)) {
int audioCount = 0;
for (const auto& s : baseStates) {
if (s->receivedAudio.load()) audioCount++;
}
if (audioCount == (int)baseStates.size()) {
groupLog("Churn", "all %d base participants receiving audio", (int)baseStates.size());
break;
}
groupLog("Churn", "base audio: %d/%d", audioCount, (int)baseStates.size());
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
// --- Phase 2: Churn loop ---
groupLog("Churn", "starting churn: %d cycles", churnCycles);
int nextId = baseCount;
int completedCycles = 0;
for (int cycle = 0; cycle < churnCycles; ++cycle) {
bool isReference = (cycle % 2 == 1);
int churnId = nextId++;
auto churner = createParticipant(churnId, isReference, sfuHandle, threads, quiet, video, &baseStates);
if (!churner) {
groupLog("Churn", "cycle %d: createParticipant failed for id=%d", cycle, churnId);
anyFailed = true;
continue;
}
// Wait briefly for connection (up to 3s)
auto connStart = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() - connStart < std::chrono::seconds(3)) {
if (churner->wasConnected.load()) break;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
if (!churner->wasConnected.load()) {
groupLog("Churn", "cycle %d: churner %d did not connect (continuing anyway)", cycle, churnId);
}
// Leave
stopParticipant(churner.get(), sfuHandle);
completedCycles++;
if ((cycle + 1) % 10 == 0) {
groupLog("Churn", "progress: %d/%d cycles completed", cycle + 1, churnCycles);
}
}
groupLog("Churn", "churn complete: %d/%d cycles succeeded", completedCycles, churnCycles);
// --- Phase 3: Stabilize and validate ---
groupLog("Churn", "stabilizing for %d seconds...", duration);
std::this_thread::sleep_for(std::chrono::seconds(duration));
auto result = validateGroupState(baseStates, video);
// --- Phase 4: Teardown ---
groupLog("Churn", "stopping base participants...");
// Stop video sources
for (auto& s : baseStates) {
if (s->videoSource) {
s->videoSource->Stop();
}
}
// Stop instances
std::atomic<int> stopCount{0};
std::mutex stopMutex;
std::condition_variable stopCv;
for (const auto& s : baseStates) {
if (s->instance) {
int pid_local = s->id;
s->instance->stop([&stopCount, &stopMutex, &stopCv, pid_local]() {
groupLog("Churn", "base participant %d stopped", pid_local);
stopCount.fetch_add(1);
std::lock_guard<std::mutex> lock(stopMutex);
stopCv.notify_all();
});
}
}
{
std::unique_lock<std::mutex> lock(stopMutex);
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
return stopCount.load() >= (int)baseStates.size();
});
}
for (auto& s : baseStates) {
s->instance.reset();
}
GoSfu_Destroy(sfuHandle);
GoSfu_Shutdown();
// Print summary
bool success = result.success && !anyFailed && (completedCycles == churnCycles);
printf("\n=== Group Churn Test Summary ===\n");
printf("Base participants: %d (custom=%d, reference=%d)\n",
baseCount, customParticipants, referenceParticipants);
printf("Churn cycles: %d/%d completed\n", completedCycles, churnCycles);
printf("Video: %s\n", video ? "yes" : "no");
printf("Stabilization: %ds\n", duration);
printf("Base connected: %d/%d\n", result.connectedCount, result.totalParticipants);
printf("Base audio received: %d/%d\n", result.audioReceivedCount, result.totalParticipants);
if (video) {
printf("Base video received: %d/%d\n", result.videoReceivedPairs, result.videoExpectedPairs);
}
printf("Result: %s\n", success ? "SUCCESS" : "FAILED");
// Clean up log files
for (const auto& s : baseStates) {
unlink(s->logPath.c_str());
}
fflush(stdout);
fflush(stderr);
_exit(success ? 0 : 1);
}
+10
View File
@@ -0,0 +1,10 @@
#pragma once
int runGroupChurnMode(
int customParticipants,
int referenceParticipants,
int duration,
bool quiet,
bool video,
int churnCycles
);
+173
View File
@@ -0,0 +1,173 @@
#include "group_mode.h"
#include "group_participant.h"
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <string>
#include <thread>
#include <unistd.h>
// CGo header
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
int runGroupMode(int customParticipants, int referenceParticipants, int duration, bool quiet, bool video, const std::string& networkScenario) {
gGroupQuiet = quiet;
gGroupStartTime = std::chrono::steady_clock::now();
int participants = customParticipants + referenceParticipants;
if (participants < 2) {
fprintf(stderr, "Error: need at least 2 participants total\n");
return 1;
}
groupLog("Group", "initializing Go SFU...");
int rc = GoSfu_Init();
if (rc != 0) {
fprintf(stderr, "Error: GoSfu_Init failed with %d\n", rc);
return 1;
}
GoInt sfuHandle = GoSfu_Create();
if (sfuHandle <= 0) {
fprintf(stderr, "Error: GoSfu_Create failed\n");
return 1;
}
groupLog("Group", "created SFU handle=%lld, custom=%d, reference=%d, duration=%ds",
(long long)sfuHandle, customParticipants, referenceParticipants, duration);
auto threads = tgcalls::StaticThreads::getThreads();
// Create participants
std::vector<std::unique_ptr<ParticipantState>> states;
bool anyFailed = false;
for (int i = 0; i < participants; ++i) {
bool isReference = (i >= customParticipants);
auto state = createParticipant(i, isReference, sfuHandle, threads, quiet, video, &states);
if (!state) {
anyFailed = true;
continue;
}
states.push_back(std::move(state));
}
// Wait for all participants to connect
groupLog("Group", "waiting for connections...");
bool allConnected = false;
auto waitStart = std::chrono::steady_clock::now();
while (std::chrono::steady_clock::now() - waitStart < std::chrono::seconds(15)) {
int connectedCount = 0;
for (const auto& s : states) {
if (s->wasConnected.load()) connectedCount++;
}
if (connectedCount == (int)states.size()) {
allConnected = true;
groupLog("Group", "all %d participants connected", (int)states.size());
break;
}
groupLog("Group", "connected: %d/%d", connectedCount, (int)states.size());
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
if (!allConnected) {
int connectedCount = 0;
for (const auto& s : states) {
if (s->wasConnected.load()) connectedCount++;
}
groupLog("Group", "connection timeout: %d/%d connected", connectedCount, (int)states.size());
}
// Run for the specified duration, optionally with network scenario.
if (!networkScenario.empty() && networkScenario == "step-down-up") {
// Scenario: start uncapped, then step down, step up, uncap.
// Split duration into 4 phases.
int phase = std::max(duration / 4, 2);
groupLog("Group", "network-scenario '%s': phase duration=%ds", networkScenario.c_str(), phase);
// Phase 1: uncapped (should be layer 2 on high BW).
groupLog("Group", "phase 1: uncapped");
std::this_thread::sleep_for(std::chrono::seconds(phase));
// Phase 2: cap to 80 kbps (should force downswitch to layer 0).
groupLog("Group", "phase 2: cap 80kbps");
for (const auto& s : states) {
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 80000);
}
std::this_thread::sleep_for(std::chrono::seconds(phase));
// Phase 3: cap to 200 kbps (should allow upswitch to layer 1).
groupLog("Group", "phase 3: cap 200kbps");
for (const auto& s : states) {
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 200000);
}
std::this_thread::sleep_for(std::chrono::seconds(phase));
// Phase 4: uncap (should allow upswitch to layer 2).
groupLog("Group", "phase 4: uncapped");
for (const auto& s : states) {
GoSfu_SetNetworkParams(sfuHandle, s->id, 1, 0, 0, 0.0, 0);
}
std::this_thread::sleep_for(std::chrono::seconds(phase));
} else {
groupLog("Group", "running for %d seconds...", duration);
std::this_thread::sleep_for(std::chrono::seconds(duration));
}
// Stop all participants (using GoSfu_Destroy for bulk teardown)
groupLog("Group", "stopping participants...");
// Stop video sources first
for (auto& s : states) {
if (s->videoSource) {
s->videoSource->Stop();
}
}
// Stop instances
std::atomic<int> stopCount{0};
std::mutex stopMutex;
std::condition_variable stopCv;
for (const auto& s : states) {
if (s->instance) {
int pid_local = s->id;
s->instance->stop([&stopCount, &stopMutex, &stopCv, pid_local]() {
groupLog("Group", "participant %d stopped", pid_local);
stopCount.fetch_add(1);
std::lock_guard<std::mutex> lock(stopMutex);
stopCv.notify_all();
});
}
}
{
std::unique_lock<std::mutex> lock(stopMutex);
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
return stopCount.load() >= (int)states.size();
});
}
for (auto& s : states) {
s->instance.reset();
}
// Destroy SFU
GoSfu_Destroy(sfuHandle);
GoSfu_Shutdown();
// Validate and print summary
auto result = validateGroupState(states, video);
bool success = printGroupSummary(customParticipants, referenceParticipants, duration, video, result, anyFailed);
// Clean up log files
for (const auto& s : states) {
unlink(s->logPath.c_str());
}
fflush(stdout);
fflush(stderr);
_exit(success ? 0 : 1);
}
+5
View File
@@ -0,0 +1,5 @@
#pragma once
#include <string>
int runGroupMode(int customParticipants, int referenceParticipants, int duration, bool quiet, bool video, const std::string& networkScenario = "");
+442
View File
@@ -0,0 +1,442 @@
#include "group_participant.h"
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <set>
#include <thread>
#include <unistd.h>
#include "third-party/json11.hpp"
// ---------------------------------------------------------------------------
// Globals
// ---------------------------------------------------------------------------
std::chrono::steady_clock::time_point gGroupStartTime = std::chrono::steady_clock::now();
std::atomic<bool> gGroupQuiet{false};
double groupElapsed() {
auto now = std::chrono::steady_clock::now();
return std::chrono::duration<double>(now - gGroupStartTime).count();
}
void groupLog(const char* tag, const char* fmt, ...) {
if (gGroupQuiet) return;
char buf[512];
va_list ap;
va_start(ap, fmt);
vsnprintf(buf, sizeof(buf), fmt, ap);
va_end(ap);
fprintf(stderr, "[%7.3f] %s: %s\n", groupElapsed(), tag, buf);
}
// ---------------------------------------------------------------------------
// GroupSineRecorder
// ---------------------------------------------------------------------------
GroupSineRecorder::GroupSineRecorder() {
buffer_.resize(kFrameSamples * kChannels);
}
tgcalls::AudioFrame GroupSineRecorder::Record() {
for (size_t i = 0; i < kFrameSamples; ++i) {
double t = static_cast<double>(phase_) / kSampleRate;
int16_t sample = static_cast<int16_t>(kAmplitude * std::sin(2.0 * M_PI * kFrequency * t));
for (size_t ch = 0; ch < kChannels; ++ch) {
buffer_[i * kChannels + ch] = sample;
}
++phase_;
}
tgcalls::AudioFrame frame;
frame.audio_samples = buffer_.data();
frame.num_samples = kFrameSamples;
frame.bytes_per_sample = sizeof(int16_t);
frame.num_channels = kChannels;
frame.samples_per_sec = kSampleRate;
frame.elapsed_time_ms = 0;
frame.ntp_time_ms = 0;
return frame;
}
int32_t GroupSineRecorder::WaitForUs() {
return 10000; // 10ms
}
// ---------------------------------------------------------------------------
// GroupNoOpRenderer
// ---------------------------------------------------------------------------
bool GroupNoOpRenderer::Render(const tgcalls::AudioFrame&) { return true; }
// ---------------------------------------------------------------------------
// SimpleRequestMediaChannelDescriptionTask
// ---------------------------------------------------------------------------
void SimpleRequestMediaChannelDescriptionTask::cancel() {}
// ---------------------------------------------------------------------------
// createParticipant
// ---------------------------------------------------------------------------
std::unique_ptr<ParticipantState> createParticipant(
int id,
bool isReference,
GoInt sfuHandle,
std::shared_ptr<tgcalls::Threads> threads,
bool quiet,
bool video,
std::vector<std::unique_ptr<ParticipantState>>* allStates
) {
auto state = std::make_unique<ParticipantState>();
state->id = id;
state->isReference = isReference;
state->logPath = "/tmp/tgcalls_group_p" + std::to_string(id) + "_" + std::to_string(getpid()) + ".log";
std::string tag = "P" + std::to_string(id);
auto recorder = std::make_shared<GroupSineRecorder>();
auto renderer = std::make_shared<GroupNoOpRenderer>();
ParticipantState* statePtr = state.get();
GoInt sfuH = sfuHandle;
tgcalls::GroupInstanceDescriptor descriptor;
descriptor.threads = threads;
descriptor.config.need_log = true;
descriptor.config.logPath = {state->logPath};
descriptor.networkStateUpdated = [statePtr, tag](tgcalls::GroupNetworkState networkState) {
groupLog(tag.c_str(), "network state: connected=%s", networkState.isConnected ? "true" : "false");
statePtr->connected.store(networkState.isConnected);
if (networkState.isConnected) {
statePtr->wasConnected.store(true);
}
};
descriptor.audioLevelsUpdated = [statePtr, tag](tgcalls::GroupLevelsUpdate const &update) {
for (const auto& level : update.updates) {
if (level.value.level > 0.01f) {
groupLog(tag.c_str(), "audio level: ssrc=%u level=%.3f voice=%d",
level.ssrc, level.value.level, level.value.voice);
}
if (level.ssrc != 0 && level.value.level > 0.05f) {
statePtr->receivedAudio.store(true);
}
}
};
descriptor.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
renderer, recorder,
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
);
descriptor.requestMediaChannelDescriptions = [sfuH, tag, allStates](
std::vector<uint32_t> const &ssrcs,
std::function<void(std::vector<tgcalls::MediaChannelDescription> &&)> callback
) -> std::shared_ptr<tgcalls::RequestMediaChannelDescriptionTask> {
std::set<uint32_t> audioSsrcs;
for (const auto& s : *allStates) {
if (s->audioSsrc != 0) audioSsrcs.insert(s->audioSsrc);
}
std::vector<tgcalls::MediaChannelDescription> descriptions;
for (uint32_t ssrc : ssrcs) {
GoInt ownerID = GoSfu_QuerySsrc(sfuH, (GoUint)ssrc);
bool isAudio = audioSsrcs.count(ssrc) > 0;
groupLog(tag.c_str(), "requestMediaChannelDescriptions: ssrc=%u -> owner=%lld type=%s",
ssrc, (long long)ownerID, isAudio ? "audio" : "video");
tgcalls::MediaChannelDescription desc;
desc.type = isAudio ? tgcalls::MediaChannelDescription::Type::Audio
: tgcalls::MediaChannelDescription::Type::Video;
desc.audioSsrc = ssrc;
desc.userId = ownerID;
descriptions.push_back(std::move(desc));
}
callback(std::move(descriptions));
return std::make_shared<SimpleRequestMediaChannelDescriptionTask>();
};
descriptor.outgoingAudioBitrateKbit = 32;
descriptor.disableIncomingChannels = false;
descriptor.useDummyChannel = true;
// Video configuration
if (video) {
auto videoSource = FakeVideoTrackSource::Create(id);
state->videoSource = videoSource;
state->endpointId = std::to_string(id);
descriptor.videoContentType = tgcalls::VideoContentType::Generic;
descriptor.videoCodecPreferences = {tgcalls::VideoCodecName::H264};
// Set the outgoing video min bitrate to 600 kbps so the sender's
// BWE floor is high enough to activate all 3 simulcast layers
// (audio 32k + L0 min 50k + L1 min 100k + L2 min 300k = 482k).
// On localhost, delay-based BWE over the loopback pacer has been
// observed to drift down to ~80 kbps, keeping L2 disabled. Clamping
// the min forces the encoder to keep L2 producing.
descriptor.minOutgoingVideoBitrateKbit = 600;
descriptor.getVideoSource = [videoSource]() -> webrtc::scoped_refptr<webrtc::VideoTrackSourceInterface> {
return videoSource;
};
descriptor.dataChannelMessageReceived = [statePtr, sfuH, tag](std::string const &message) {
std::string parseErr;
auto json = json11::Json::parse(message, parseErr);
if (!parseErr.empty() || !json.is_object()) return;
auto cls = json["colibriClass"].string_value();
if (cls != "ActiveVideoSsrcs") return;
auto ssrcsArray = json["ssrcs"].array_items();
if (ssrcsArray.empty()) return;
std::vector<tgcalls::VideoChannelDescription> videoChannels;
for (const auto& entry : ssrcsArray) {
std::string endpointId = entry["endpointId"].string_value();
if (endpointId == statePtr->endpointId) continue;
{
std::lock_guard<std::mutex> lock(statePtr->videoSinksMutex);
if (statePtr->videoSinks.count(endpointId) > 0) continue;
}
int remoteId = 0;
if (sscanf(endpointId.c_str(), "%d", &remoteId) != 1) continue;
char* ssrcsRaw = GoSfu_QueryVideoSsrcs(sfuH, (GoInt)remoteId);
if (!ssrcsRaw) continue;
std::string ssrcsJson(ssrcsRaw);
GoSfu_Free(ssrcsRaw);
std::string err2;
auto layers = json11::Json::parse(ssrcsJson, err2);
if (!err2.empty() || !layers.is_array() || layers.array_items().empty()) continue;
tgcalls::VideoChannelDescription desc;
desc.audioSsrc = 0;
desc.userId = remoteId;
desc.endpointId = endpointId;
desc.maxQuality = tgcalls::VideoChannelDescription::Quality::Full;
desc.minQuality = tgcalls::VideoChannelDescription::Quality::Full;
tgcalls::MediaSsrcGroup simGroup;
simGroup.semantics = "SIM";
for (const auto& layer : layers.array_items()) {
uint32_t ssrc = static_cast<uint32_t>(static_cast<int64_t>(layer["ssrc"].number_value()));
uint32_t fidSsrc = static_cast<uint32_t>(static_cast<int64_t>(layer["fidSsrc"].number_value()));
if (ssrc == 0) continue;
simGroup.ssrcs.push_back(ssrc);
if (fidSsrc != 0) {
tgcalls::MediaSsrcGroup fidGroup;
fidGroup.semantics = "FID";
fidGroup.ssrcs = {ssrc, fidSsrc};
desc.ssrcGroups.push_back(std::move(fidGroup));
}
}
desc.ssrcGroups.insert(desc.ssrcGroups.begin(), std::move(simGroup));
videoChannels.push_back(std::move(desc));
auto sink = std::make_shared<FakeVideoSink>();
{
std::lock_guard<std::mutex> lock(statePtr->videoSinksMutex);
statePtr->videoSinks[endpointId] = sink;
}
statePtr->instance->addIncomingVideoOutput(
endpointId,
std::weak_ptr<rtc::VideoSinkInterface<webrtc::VideoFrame>>(sink));
groupLog(tag.c_str(), "ActiveVideoSsrcs: adding video channel for endpoint %s", endpointId.c_str());
}
if (!videoChannels.empty()) {
statePtr->instance->setRequestedVideoChannels(std::move(videoChannels));
}
};
} else {
descriptor.videoContentType = tgcalls::VideoContentType::None;
}
// Create instance
if (isReference) {
state->instance = std::make_unique<tgcalls::GroupInstanceReferenceImpl>(std::move(descriptor));
groupLog(tag.c_str(), "created GroupInstanceReferenceImpl");
} else {
state->instance = std::make_unique<tgcalls::GroupInstanceCustomImpl>(std::move(descriptor));
groupLog(tag.c_str(), "created GroupInstanceCustomImpl");
}
// Set connection mode
state->instance->setConnectionMode(
tgcalls::GroupConnectionMode::GroupConnectionModeRtc, false, false);
// Emit join payload
std::mutex joinMutex;
std::condition_variable joinCv;
bool joinReady = false;
std::string joinJson;
uint32_t joinSsrc = 0;
state->instance->emitJoinPayload([&](tgcalls::GroupJoinPayload const &payload) {
std::lock_guard<std::mutex> lock(joinMutex);
joinJson = payload.json;
joinSsrc = payload.audioSsrc;
joinReady = true;
joinCv.notify_one();
});
{
std::unique_lock<std::mutex> lock(joinMutex);
if (!joinCv.wait_for(lock, std::chrono::seconds(5), [&] { return joinReady; })) {
fprintf(stderr, "Error: emitJoinPayload timed out for participant %d\n", id);
return nullptr;
}
}
state->audioSsrc = joinSsrc;
groupLog(tag.c_str(), "join payload ready: ssrc=%u, json=%zu bytes", joinSsrc, joinJson.size());
// Join SFU
GoInt iceControlling = isReference ? 0 : 1;
char* responseRaw = GoSfu_Join(sfuHandle, (GoInt)id, const_cast<char*>(joinJson.c_str()), iceControlling);
if (!responseRaw) {
fprintf(stderr, "Error: GoSfu_Join returned null for participant %d\n", id);
return nullptr;
}
std::string response(responseRaw);
GoSfu_Free(responseRaw);
if (response.find("\"error\"") != std::string::npos) {
fprintf(stderr, "Error: GoSfu_Join failed for participant %d: %s\n", id, response.c_str());
return nullptr;
}
groupLog(tag.c_str(), "SFU join response: %zu bytes", response.size());
state->instance->setJoinResponsePayload(response);
state->instance->setIsMuted(false);
groupLog(tag.c_str(), "joined and unmuted");
return state;
}
// ---------------------------------------------------------------------------
// stopParticipant
// ---------------------------------------------------------------------------
void stopParticipant(ParticipantState* state, GoInt sfuHandle) {
if (!state || !state->instance) return;
std::string tag = "P" + std::to_string(state->id);
// Remove from SFU first so broadcasts go out to remaining participants.
GoInt rc = GoSfu_Leave(sfuHandle, (GoInt)state->id);
if (rc != 0) {
groupLog(tag.c_str(), "GoSfu_Leave returned %lld (may already be removed)", (long long)rc);
}
// Stop video source.
if (state->videoSource) {
state->videoSource->Stop();
}
// Stop instance with timeout. Heap-allocate sync state so the stop callback
// is safe even if it fires after the 5s timeout (avoids stack-frame UB).
struct StopState {
std::mutex mu;
std::condition_variable cv;
std::atomic<bool> done{false};
};
auto stopState = std::make_shared<StopState>();
state->instance->stop([stopState]() {
stopState->done.store(true);
std::lock_guard<std::mutex> lock(stopState->mu);
stopState->cv.notify_all();
});
{
std::unique_lock<std::mutex> lock(stopState->mu);
stopState->cv.wait_for(lock, std::chrono::seconds(5), [&] { return stopState->done.load(); });
}
state->instance.reset();
// Clean up log file.
unlink(state->logPath.c_str());
groupLog(tag.c_str(), "stopped and cleaned up");
}
// ---------------------------------------------------------------------------
// validateGroupState
// ---------------------------------------------------------------------------
GroupValidationResult validateGroupState(
const std::vector<std::unique_ptr<ParticipantState>>& states,
bool video
) {
GroupValidationResult result{};
result.totalParticipants = static_cast<int>(states.size());
for (const auto& s : states) {
if (s->wasConnected.load()) result.connectedCount++;
if (s->receivedAudio.load()) result.audioReceivedCount++;
}
if (video) {
int videoParticipants = 0;
for (const auto& s : states) {
if (s->videoSource) videoParticipants++;
}
result.videoExpectedPairs = videoParticipants * (videoParticipants - 1);
for (const auto& s : states) {
std::lock_guard<std::mutex> lock(s->videoSinksMutex);
for (const auto& [endpointId, sink] : s->videoSinks) {
int frames = sink->frameCount();
if (frames > 0) {
result.videoReceivedPairs++;
}
groupLog("Validate", "P%d <- endpoint %s: %d video frames (%dx%d)",
s->id, endpointId.c_str(), frames,
sink->lastWidth(), sink->lastHeight());
}
}
}
result.success = (result.connectedCount == result.totalParticipants &&
result.audioReceivedCount == result.totalParticipants);
if (video && result.videoExpectedPairs > 0) {
result.success = result.success && (result.videoReceivedPairs >= result.videoExpectedPairs);
}
return result;
}
// ---------------------------------------------------------------------------
// printGroupSummary
// ---------------------------------------------------------------------------
bool printGroupSummary(
int customParticipants,
int referenceParticipants,
int duration,
bool video,
const GroupValidationResult& result,
bool anyFailed
) {
bool success = result.success && !anyFailed;
printf("\n=== Group Call Summary ===\n");
printf("Custom participants: %d\n", customParticipants);
printf("Reference participants: %d\n", referenceParticipants);
printf("Total participants: %d\n", result.totalParticipants);
printf("Duration: %ds\n", duration);
printf("SFU: Go/Pion (in-process)\n");
printf("Connected: %d/%d\n", result.connectedCount, result.totalParticipants);
printf("Audio received: %d/%d\n", result.audioReceivedCount, result.totalParticipants);
if (video) {
printf("Video received: %d/%d\n", result.videoReceivedPairs, result.videoExpectedPairs);
}
printf("Result: %s\n", success ? "SUCCESS" : "FAILED");
return success;
}
+143
View File
@@ -0,0 +1,143 @@
#pragma once
#include <atomic>
#include <chrono>
#include <cstdarg>
#include <cstdio>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>
#include "group/GroupInstanceCustomImpl.h"
#include "group/GroupInstanceImpl.h"
#include "group/GroupInstanceReferenceImpl.h"
#include "FakeAudioDeviceModule.h"
#include "StaticThreads.h"
#include "AudioFrame.h"
#include "fake_video_source.h"
#include "fake_video_sink.h"
// CGo header
#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"
// ---------------------------------------------------------------------------
// Logging helpers
// ---------------------------------------------------------------------------
extern std::chrono::steady_clock::time_point gGroupStartTime;
extern std::atomic<bool> gGroupQuiet;
double groupElapsed();
void groupLog(const char* tag, const char* fmt, ...);
// ---------------------------------------------------------------------------
// GroupSineRecorder - generates 440 Hz sine tone
// ---------------------------------------------------------------------------
class GroupSineRecorder : public tgcalls::FakeAudioDeviceModule::Recorder {
public:
GroupSineRecorder();
tgcalls::AudioFrame Record() override;
int32_t WaitForUs() override;
private:
static constexpr size_t kSampleRate = 48000;
static constexpr size_t kChannels = 2;
static constexpr size_t kFrameSamples = 480;
static constexpr double kFrequency = 440.0;
static constexpr double kAmplitude = 3000.0;
std::vector<int16_t> buffer_;
uint64_t phase_ = 0;
};
// ---------------------------------------------------------------------------
// GroupNoOpRenderer - discards received audio
// ---------------------------------------------------------------------------
class GroupNoOpRenderer : public tgcalls::FakeAudioDeviceModule::Renderer {
public:
bool Render(const tgcalls::AudioFrame&) override;
};
// ---------------------------------------------------------------------------
// SimpleRequestMediaChannelDescriptionTask
// ---------------------------------------------------------------------------
class SimpleRequestMediaChannelDescriptionTask : public tgcalls::RequestMediaChannelDescriptionTask {
public:
void cancel() override;
};
// ---------------------------------------------------------------------------
// ParticipantState
// ---------------------------------------------------------------------------
struct ParticipantState {
int id;
bool isReference;
std::unique_ptr<tgcalls::GroupInstanceInterface> instance;
std::atomic<bool> connected{false};
std::atomic<bool> wasConnected{false};
std::atomic<bool> receivedAudio{false};
uint32_t audioSsrc{0};
std::string logPath;
// Video fields
std::string endpointId;
rtc::scoped_refptr<FakeVideoTrackSource> videoSource;
std::mutex videoSinksMutex;
std::map<std::string, std::shared_ptr<FakeVideoSink>> videoSinks;
};
// ---------------------------------------------------------------------------
// GroupValidationResult
// ---------------------------------------------------------------------------
struct GroupValidationResult {
int totalParticipants;
int connectedCount;
int audioReceivedCount;
int videoReceivedPairs;
int videoExpectedPairs;
bool success;
};
// ---------------------------------------------------------------------------
// Participant lifecycle functions
// ---------------------------------------------------------------------------
// Creates a fully initialized participant: builds descriptor, creates instance,
// joins SFU, sets join response, unmutes. Returns nullptr on failure.
std::unique_ptr<ParticipantState> createParticipant(
int id,
bool isReference,
GoInt sfuHandle,
std::shared_ptr<tgcalls::Threads> threads,
bool quiet,
bool video,
std::vector<std::unique_ptr<ParticipantState>>* allStates
);
// Clean teardown: GoSfu_Leave, stop video, stop instance, reset.
void stopParticipant(ParticipantState* state, GoInt sfuHandle);
// Validates group state: connection, audio, video. Returns result struct.
GroupValidationResult validateGroupState(
const std::vector<std::unique_ptr<ParticipantState>>& states,
bool video
);
// Prints a group call summary to stdout. Returns the success boolean.
bool printGroupSummary(
int customParticipants,
int referenceParticipants,
int duration,
bool video,
const GroupValidationResult& result,
bool anyFailed
);
+602
View File
@@ -0,0 +1,602 @@
#include <array>
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <random>
#include <fstream>
#include <string>
#include <thread>
#include <unistd.h>
#include <vector>
#include "group_mode.h"
#include "group_churn_mode.h"
#include "Instance.h"
#include "FakeAudioDeviceModule.h"
#include "VideoCaptureInterface.h"
#include "v2/InstanceV2Impl.h"
#include "v2/InstanceV2CompatImpl.h"
#include "v2/InstanceV2ReferenceImpl.h"
#include "modules/audio_device/include/audio_device.h"
#include "api/task_queue/task_queue_factory.h"
// Stub: AudioDeviceModule::Create is referenced by InstanceV2Impl as a fallback
// but never called when createAudioDeviceModule is provided in the Descriptor.
namespace webrtc {
rtc::scoped_refptr<AudioDeviceModule> AudioDeviceModule::Create(
AudioDeviceModule::AudioLayer audio_layer,
TaskQueueFactory* task_queue_factory) {
return nullptr;
}
} // namespace webrtc
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
static auto gStartTime = std::chrono::steady_clock::now();
static double elapsed() {
auto now = std::chrono::steady_clock::now();
return std::chrono::duration<double>(now - gStartTime).count();
}
static bool gQuiet = false;
static void logMsg(const char* role, const char* fmt, ...) {
if (gQuiet) return;
char buf[512];
va_list ap;
va_start(ap, fmt);
vsnprintf(buf, sizeof(buf), fmt, ap);
va_end(ap);
fprintf(stderr, "[%7.3f] %s: %s\n", elapsed(), role, buf);
}
static const char* stateName(tgcalls::State s) {
switch (s) {
case tgcalls::State::WaitInit: return "WaitInit";
case tgcalls::State::WaitInitAck: return "WaitInitAck";
case tgcalls::State::Established: return "Established";
case tgcalls::State::Failed: return "Failed";
case tgcalls::State::Reconnecting:return "Reconnecting";
}
return "Unknown";
}
static std::string hexEncode(const std::array<uint8_t, 16>& data) {
char buf[33];
for (size_t i = 0; i < 16; ++i) {
snprintf(buf + i * 2, 3, "%02x", data[i]);
}
return std::string(buf, 32);
}
static tgcalls::RtcServer makeReflectorServer(const std::string& host, uint16_t port,
const std::array<uint8_t, 16>& peerTag) {
tgcalls::RtcServer server;
server.id = 1;
server.host = host;
server.port = port;
server.login = "reflector";
server.password = hexEncode(peerTag);
server.isTurn = true;
server.isTcp = false;
return server;
}
// ---------------------------------------------------------------------------
// SineRecorder - generates 440 Hz sine tone
// ---------------------------------------------------------------------------
class SineRecorder : public tgcalls::FakeAudioDeviceModule::Recorder {
public:
SineRecorder() {
buffer_.resize(kFrameSamples * kChannels);
}
tgcalls::AudioFrame Record() override {
for (size_t i = 0; i < kFrameSamples; ++i) {
double t = static_cast<double>(phase_) / kSampleRate;
int16_t sample = static_cast<int16_t>(kAmplitude * std::sin(2.0 * M_PI * kFrequency * t));
for (size_t ch = 0; ch < kChannels; ++ch) {
buffer_[i * kChannels + ch] = sample;
}
++phase_;
}
tgcalls::AudioFrame frame;
frame.audio_samples = buffer_.data();
frame.num_samples = kFrameSamples;
frame.bytes_per_sample = sizeof(int16_t);
frame.num_channels = kChannels;
frame.samples_per_sec = kSampleRate;
frame.elapsed_time_ms = 0;
frame.ntp_time_ms = 0;
return frame;
}
int32_t WaitForUs() override {
return 10000; // 10ms
}
private:
static constexpr size_t kSampleRate = 48000;
static constexpr size_t kChannels = 2;
static constexpr size_t kFrameSamples = 480; // 10ms at 48kHz
static constexpr double kFrequency = 440.0;
static constexpr double kAmplitude = 3000.0;
std::vector<int16_t> buffer_;
uint64_t phase_ = 0;
};
// ---------------------------------------------------------------------------
// NoOpRenderer - discards received audio (validation is done via BWE stats)
// ---------------------------------------------------------------------------
class NoOpRenderer : public tgcalls::FakeAudioDeviceModule::Renderer {
public:
bool Render(const tgcalls::AudioFrame&) override { return true; }
};
// ---------------------------------------------------------------------------
// SignalingBridge
// ---------------------------------------------------------------------------
struct SignalingBridge {
std::mutex mutex;
std::shared_ptr<tgcalls::Instance> caller;
std::shared_ptr<tgcalls::Instance> callee;
// Network simulation
double dropRate = 0.0;
int delayMinMs = 0;
int delayMaxMs = 0;
std::mt19937 rng{std::random_device{}()};
void deliver(const char* fromRole, const std::vector<uint8_t>& data,
std::shared_ptr<tgcalls::Instance>& target) {
if (dropRate > 0.0) {
std::uniform_real_distribution<double> dropDist(0.0, 1.0);
if (dropDist(rng) < dropRate) {
logMsg(fromRole, "signaling DROPPED (%zu bytes)", data.size());
return;
}
}
if (delayMaxMs > 0) {
std::uniform_int_distribution<int> delayDist(delayMinMs, delayMaxMs);
int delayMs = delayDist(rng);
if (delayMs > 0) {
logMsg(fromRole, "signaling delayed %dms (%zu bytes)", delayMs, data.size());
auto dataCopy = data;
auto targetWeak = std::weak_ptr<tgcalls::Instance>(target);
std::thread([dataCopy, targetWeak, delayMs]() {
std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
if (auto t = targetWeak.lock()) {
t->receiveSignalingData(dataCopy);
}
}).detach();
return;
}
}
if (target) {
target->receiveSignalingData(data);
}
}
};
// ---------------------------------------------------------------------------
// CallState
// ---------------------------------------------------------------------------
struct CallState {
std::mutex mutex;
tgcalls::State callerState = tgcalls::State::WaitInit;
tgcalls::State calleeState = tgcalls::State::WaitInit;
double establishedAt = -1.0;
std::vector<std::string> errors;
};
// ---------------------------------------------------------------------------
// main
// ---------------------------------------------------------------------------
int main(int argc, char* argv[]) {
int duration = 10;
std::string mode;
std::string reflectorAddr;
std::string reflectorList;
std::string version = "13.0.0";
std::string version2;
double dropRate = 0.0;
int delayMinMs = 0;
int delayMaxMs = 0;
int participants = 3;
int referenceParticipants = 0;
bool enableVideo = false;
int churnCycles = 100;
std::string networkScenario;
for (int i = 1; i < argc; ++i) {
if (std::string(argv[i]) == "--duration" && i + 1 < argc) {
duration = std::atoi(argv[++i]);
} else if (std::string(argv[i]) == "--quiet") {
gQuiet = true;
} else if (std::string(argv[i]) == "--mode" && i + 1 < argc) {
mode = argv[++i];
} else if (std::string(argv[i]) == "--reflector" && i + 1 < argc) {
reflectorAddr = argv[++i];
} else if (std::string(argv[i]) == "--reflector-list" && i + 1 < argc) {
reflectorList = argv[++i];
} else if (std::string(argv[i]) == "--drop-rate" && i + 1 < argc) {
dropRate = std::atof(argv[++i]);
} else if (std::string(argv[i]) == "--version" && i + 1 < argc) {
version = argv[++i];
} else if (std::string(argv[i]) == "--version2" && i + 1 < argc) {
version2 = argv[++i];
} else if (std::string(argv[i]) == "--participants" && i + 1 < argc) {
participants = std::atoi(argv[++i]);
} else if (std::string(argv[i]) == "--reference-participants" && i + 1 < argc) {
referenceParticipants = std::atoi(argv[++i]);
} else if (std::string(argv[i]) == "--video") {
enableVideo = true;
} else if (std::string(argv[i]) == "--churn-cycles" && i + 1 < argc) {
churnCycles = std::atoi(argv[++i]);
} else if (std::string(argv[i]) == "--network-scenario" && i + 1 < argc) {
networkScenario = argv[++i];
} else if (std::string(argv[i]) == "--delay" && i + 1 < argc) {
std::string delayStr = argv[++i];
auto dashPos = delayStr.find('-');
if (dashPos != std::string::npos) {
delayMinMs = std::atoi(delayStr.substr(0, dashPos).c_str());
delayMaxMs = std::atoi(delayStr.substr(dashPos + 1).c_str());
} else {
delayMinMs = 0;
delayMaxMs = std::atoi(delayStr.c_str());
}
}
}
if (version2.empty()) {
version2 = version;
}
// If --reflector-list provided, pick one at random
if (!reflectorList.empty()) {
std::vector<std::string> addrs;
size_t pos = 0;
while (pos < reflectorList.size()) {
size_t next = reflectorList.find(',', pos);
if (next == std::string::npos) next = reflectorList.size();
std::string addr = reflectorList.substr(pos, next - pos);
if (!addr.empty()) addrs.push_back(addr);
pos = next + 1;
}
if (addrs.empty()) {
fprintf(stderr, "Error: --reflector-list is empty\n");
return 1;
}
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<size_t> dist(0, addrs.size() - 1);
reflectorAddr = addrs[dist(rng)];
if (reflectorAddr.rfind(':') == std::string::npos) {
std::uniform_int_distribution<int> portDist(596, 599);
reflectorAddr += ":" + std::to_string(portDist(rng));
}
if (mode.empty()) mode = "reflector";
}
// Validate --mode
if (mode.empty()) {
fprintf(stderr, "Error: --mode is required (p2p, reflector, group, or group-churn)\n");
return 1;
}
if (mode != "p2p" && mode != "reflector" && mode != "group" && mode != "group-churn") {
fprintf(stderr, "Error: --mode must be 'p2p', 'reflector', 'group', or 'group-churn'\n");
return 1;
}
// Group mode: dispatch to separate implementation
if (mode == "group") {
return runGroupMode(participants, referenceParticipants, duration, gQuiet, enableVideo, networkScenario);
}
if (mode == "group-churn") {
return runGroupChurnMode(participants, referenceParticipants, duration, gQuiet, enableVideo, churnCycles);
}
if (mode == "reflector" && reflectorAddr.empty()) {
fprintf(stderr, "Error: --reflector host:port is required with --mode reflector\n");
return 1;
}
if (mode == "p2p" && !reflectorAddr.empty()) {
fprintf(stderr, "Error: --reflector cannot be used with --mode p2p\n");
return 1;
}
// Parse reflector address
std::string reflectorHost;
uint16_t reflectorPort = 0;
if (mode == "reflector") {
auto colonPos = reflectorAddr.rfind(':');
if (colonPos == std::string::npos) {
fprintf(stderr, "Error: --reflector must be in host:port format\n");
return 1;
}
reflectorHost = reflectorAddr.substr(0, colonPos);
reflectorPort = static_cast<uint16_t>(std::atoi(reflectorAddr.substr(colonPos + 1).c_str()));
if (reflectorPort == 0) {
fprintf(stderr, "Error: invalid reflector port\n");
return 1;
}
}
// Generate peer tags for reflector mode
std::array<uint8_t, 16> callerPeerTag{};
std::array<uint8_t, 16> calleePeerTag{};
if (mode == "reflector") {
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> dist(0, 255);
for (auto& b : callerPeerTag) {
b = static_cast<uint8_t>(dist(rng));
}
calleePeerTag = callerPeerTag;
callerPeerTag[0] = 0x00;
calleePeerTag[0] = 0x01;
}
// Register implementations
tgcalls::Register<tgcalls::InstanceV2Impl>();
tgcalls::Register<tgcalls::InstanceV2CompatImpl>();
tgcalls::Register<tgcalls::InstanceV2ReferenceImpl>();
// Create shared encryption key
auto keyData = std::make_shared<std::array<uint8_t, 256>>();
{
std::mt19937 rng(42);
std::uniform_int_distribution<int> dist(0, 255);
for (auto& b : *keyData) {
b = static_cast<uint8_t>(dist(rng));
}
}
// Bridge and state
auto bridge = std::make_shared<SignalingBridge>();
bridge->dropRate = dropRate;
bridge->delayMinMs = delayMinMs;
bridge->delayMaxMs = delayMaxMs;
auto callState = std::make_shared<CallState>();
// Audio components
auto callerRecorder = std::make_shared<SineRecorder>();
auto callerRenderer = std::make_shared<NoOpRenderer>();
auto calleeRecorder = std::make_shared<SineRecorder>();
auto calleeRenderer = std::make_shared<NoOpRenderer>();
// Stats log paths (per-process to avoid collisions in parallel runs)
std::string callerStatsPath = "/tmp/tgcalls_cli_caller_" + std::to_string(getpid()) + ".json";
std::string calleeStatsPath = "/tmp/tgcalls_cli_callee_" + std::to_string(getpid()) + ".json";
// --- Caller descriptor ---
auto callerDesc = (tgcalls::Descriptor){
.version = version,
.config = {
.initializationTimeout = 10.0,
.receiveTimeout = 10.0,
.enableP2P = (mode == "p2p"),
.statsLogPath = {callerStatsPath},
},
.rtcServers = (mode == "reflector")
? std::vector<tgcalls::RtcServer>{makeReflectorServer(reflectorHost, reflectorPort, callerPeerTag)}
: std::vector<tgcalls::RtcServer>{},
.encryptionKey = tgcalls::EncryptionKey(keyData, true),
.stateUpdated = [callState](tgcalls::State state) {
logMsg("Caller", "state -> %s", stateName(state));
std::lock_guard<std::mutex> lock(callState->mutex);
callState->callerState = state;
if (state == tgcalls::State::Established && callState->establishedAt < 0) {
callState->establishedAt = elapsed();
}
if (state == tgcalls::State::Failed) {
callState->errors.push_back("Caller entered Failed state");
}
},
.signalingDataEmitted = [bridge](const std::vector<uint8_t>& data) {
logMsg("Caller", "signaling data emitted (%zu bytes)", data.size());
std::lock_guard<std::mutex> lock(bridge->mutex);
bridge->deliver("Caller", data, bridge->callee);
},
.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
callerRenderer, callerRecorder,
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
),
};
// --- Callee descriptor ---
auto calleeDesc = (tgcalls::Descriptor){
.version = version2,
.config = {
.initializationTimeout = 10.0,
.receiveTimeout = 10.0,
.enableP2P = (mode == "p2p"),
.statsLogPath = {calleeStatsPath},
},
.rtcServers = (mode == "reflector")
? std::vector<tgcalls::RtcServer>{makeReflectorServer(reflectorHost, reflectorPort, calleePeerTag)}
: std::vector<tgcalls::RtcServer>{},
.encryptionKey = tgcalls::EncryptionKey(keyData, false),
.stateUpdated = [callState](tgcalls::State state) {
logMsg("Callee", "state -> %s", stateName(state));
std::lock_guard<std::mutex> lock(callState->mutex);
callState->calleeState = state;
if (state == tgcalls::State::Established && callState->establishedAt < 0) {
callState->establishedAt = elapsed();
}
if (state == tgcalls::State::Failed) {
callState->errors.push_back("Callee entered Failed state");
}
},
.signalingDataEmitted = [bridge](const std::vector<uint8_t>& data) {
logMsg("Callee", "signaling data emitted (%zu bytes)", data.size());
std::lock_guard<std::mutex> lock(bridge->mutex);
bridge->deliver("Callee", data, bridge->caller);
},
.createAudioDeviceModule = tgcalls::FakeAudioDeviceModule::Creator(
calleeRenderer, calleeRecorder,
tgcalls::FakeAudioDeviceModule::Options{.samples_per_sec = 48000, .num_channels = 2}
),
};
// Create instances
auto callerInstance = std::shared_ptr<tgcalls::Instance>(
tgcalls::Meta::Create(version, std::move(callerDesc)).release());
if (!callerInstance) {
fprintf(stderr, "Error: unknown version '%s'\n", version.c_str());
return 1;
}
logMsg("Caller", "created (version %s)", version.c_str());
auto calleeInstance = std::shared_ptr<tgcalls::Instance>(
tgcalls::Meta::Create(version2, std::move(calleeDesc)).release());
if (!calleeInstance) {
fprintf(stderr, "Error: unknown callee version '%s'\n", version2.c_str());
return 1;
}
logMsg("Callee", "created (version %s)", version2.c_str());
// Wire bridge
{
std::lock_guard<std::mutex> lock(bridge->mutex);
bridge->caller = callerInstance;
bridge->callee = calleeInstance;
}
logMsg("Main", "sleeping for %d seconds...", duration);
std::this_thread::sleep_for(std::chrono::seconds(duration));
// Stop both instances
logMsg("Main", "stopping instances...");
std::atomic<int> stopCount{0};
std::mutex stopMutex;
std::condition_variable stopCv;
auto onStopped = [&](const char* role) {
return [&, role](tgcalls::FinalState) {
logMsg(role, "stopped");
stopCount.fetch_add(1);
std::lock_guard<std::mutex> lock(stopMutex);
stopCv.notify_all();
};
};
callerInstance->stop(onStopped("Caller"));
calleeInstance->stop(onStopped("Callee"));
// Wait for both stop callbacks (up to 5 seconds)
{
std::unique_lock<std::mutex> lock(stopMutex);
stopCv.wait_for(lock, std::chrono::seconds(5), [&] {
return stopCount.load() >= 2;
});
}
// Release instances — clear bridge first to prevent signaling during teardown
{
std::lock_guard<std::mutex> lock(bridge->mutex);
bridge->caller.reset();
bridge->callee.reset();
}
callerInstance.reset();
calleeInstance.reset();
// Read stats logs: count bitrate records and check for non-zero BWE
struct StatsResult {
int bitrateRecords = 0;
bool hasNonZeroBwe = false;
};
auto parseStatsLog = [](const std::string& path) -> StatsResult {
StatsResult result;
std::ifstream f(path);
if (!f.is_open()) return result;
std::string content((std::istreambuf_iterator<char>(f)),
std::istreambuf_iterator<char>());
size_t pos = 0;
while ((pos = content.find("\"b\":", pos)) != std::string::npos) {
pos += 4;
result.bitrateRecords++;
// Parse the integer value after "b":
int val = std::atoi(content.c_str() + pos);
if (val > 0) {
result.hasNonZeroBwe = true;
}
}
return result;
};
auto callerStats = parseStatsLog(callerStatsPath);
auto calleeStats = parseStatsLog(calleeStatsPath);
unlink(callerStatsPath.c_str());
unlink(calleeStatsPath.c_str());
// Print summary
{
std::lock_guard<std::mutex> lock(callState->mutex);
bool established = (callState->establishedAt >= 0);
printf("\n=== Call Summary ===\n");
printf("Duration: %ds\n", duration);
if (dropRate > 0.0 || delayMaxMs > 0) {
printf("Signaling: drop=%.0f%% delay=%d-%dms\n",
dropRate * 100.0, delayMinMs, delayMaxMs);
}
if (mode == "reflector") {
printf("Mode: reflector (%s:%d)\n", reflectorHost.c_str(), reflectorPort);
} else {
printf("Mode: p2p\n");
}
printf("Caller state: %s\n", stateName(callState->callerState));
printf("Callee state: %s\n", stateName(callState->calleeState));
if (callState->establishedAt >= 0) {
printf("Call established: yes (at %.3fs)\n", callState->establishedAt);
} else {
printf("Call established: no\n");
}
bool bweNonZero = callerStats.hasNonZeroBwe && calleeStats.hasNonZeroBwe;
printf("Stats log: caller=%d callee=%d bitrate records\n",
callerStats.bitrateRecords, calleeStats.bitrateRecords);
printf("BWE non-zero: %s\n", bweNonZero ? "yes" : "no");
bool statsCollected = (callerStats.bitrateRecords > 0 && calleeStats.bitrateRecords > 0);
if (callState->errors.empty()) {
printf("Errors: none\n");
} else {
printf("Errors:\n");
for (const auto& err : callState->errors) {
printf(" - %s\n", err.c_str());
}
}
// Use _exit() to skip static destruction. ThreadLocalObject's destructor
// posts fire-and-forget cleanup tasks to the tgcalls media thread. If we
// return normally, static destruction tears down the StaticThreads thread
// pool while those tasks may still be executing, causing "pure virtual
// function called" when a half-destroyed object's vtable is accessed.
fflush(stdout);
fflush(stderr);
_exit(established && statsCollected && bweNonZero ? 0 : 1);
}
}
+105
View File
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Run N parallel P2P tests locally and report aggregate results.
#
# Usage:
# ./run-local-test.sh # 100 calls, 15s each, 30% loss
# ./run-local-test.sh -n 1000 # 1000 calls
# ./run-local-test.sh -n 500 -j 200 # 500 calls, 200 parallel
# ./run-local-test.sh -n 100 -d 30 # 100 calls, 30s each
# ./run-local-test.sh --drop-rate 0.5 # 50% loss
set -euo pipefail
BINARY="./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli"
NUM=100
PARALLEL=150
DURATION=15
DROP_RATE=0.3
DELAY="50-200"
MODE="p2p"
VERSION="13.0.0"
while [[ $# -gt 0 ]]; do
case $1 in
-n) NUM="$2"; shift 2 ;;
-j) PARALLEL="$2"; shift 2 ;;
-d) DURATION="$2"; shift 2 ;;
--drop-rate) DROP_RATE="$2"; shift 2 ;;
--delay) DELAY="$2"; shift 2 ;;
--mode) MODE="$2"; shift 2 ;;
--version) VERSION="$2"; shift 2 ;;
*) echo "Usage: $0 [-n NUM] [-j PARALLEL] [-d DURATION] [--drop-rate RATE] [--delay MIN-MAX] [--mode MODE] [--version VER]"; exit 1 ;;
esac
done
if [ ! -x "$BINARY" ]; then
echo "Binary not found: $BINARY"
echo "Run: ./build-input/bazel-8.4.2 build //submodules/TgVoipWebrtc/tgcalls/tools/cli:tgcalls_cli"
exit 1
fi
TMPDIR=$(mktemp -d)
trap "rm -rf $TMPDIR" EXIT
echo "Running $NUM calls ($PARALLEL parallel, ${DURATION}s each, drop=${DROP_RATE}, delay=${DELAY}ms, mode=${MODE}, version=${VERSION})"
START=$(date +%s)
launched=0
wave=0
while [ $launched -lt $NUM ]; do
wave=$((wave + 1))
remaining=$((NUM - launched))
batch=$((remaining > PARALLEL ? PARALLEL : remaining))
pids=()
for i in $(seq 1 $batch); do
id=$((launched + i))
(
if "$BINARY" --mode "$MODE" --duration "$DURATION" \
--drop-rate "$DROP_RATE" --delay "$DELAY" --version "$VERSION" --quiet \
> /dev/null 2>&1; then
echo "pass" > "$TMPDIR/$id"
else
echo "fail" > "$TMPDIR/$id"
fi
) &
pids+=($!)
done
for pid in "${pids[@]}"; do
wait "$pid" 2>/dev/null || true
done
launched=$((launched + batch))
echo " Wave $wave: $launched/$NUM done"
done
END=$(date +%s)
ELAPSED=$((END - START))
# Tally
success=0
failed=0
for f in "$TMPDIR"/*; do
[ -f "$f" ] || continue
if [ "$(cat "$f")" = "pass" ]; then
success=$((success + 1))
else
failed=$((failed + 1))
fi
done
echo ""
echo "=== Local Mass Test Results ==="
echo "Total: $NUM"
echo "Success: $success"
echo "Failed: $failed"
if [ $NUM -gt 0 ]; then
rate=$(echo "scale=1; $success * 100 / $NUM" | bc)
echo "Rate: ${rate}%"
fi
echo "Duration: ${ELAPSED}s"
echo "Parallel: $PARALLEL"
exit 0
+249
View File
@@ -0,0 +1,249 @@
#!/usr/bin/env bash
# Launch N tgcalls test tasks on ECS Fargate, spread across reflectors.
#
# Usage:
# ./run-test.sh # 10 tasks, 30s each
# ./run-test.sh -n 100 # 100 tasks
# ./run-test.sh -n 50 -d 60 # 50 tasks, 60s each
# ./run-test.sh --results # fetch results from last run
set -euo pipefail
CLUSTER="tgcalls-test"
TASK_DEF="tgcalls-test"
REGION="eu-west-1"
SUBNETS="subnet-0292f49f3b4885428,subnet-09b8edab6eb20b837,subnet-0f464b5c62c9a6d1a"
SECURITY_GROUP="sg-0d87a1f19be76c160"
LOG_GROUP="/ecs/tgcalls-test"
REFLECTOR_URL="https://core.telegram.org/getReflectorList"
RUN_FILE="/tmp/tgcalls-last-run.txt"
STATUS_FILE="/tmp/tgcalls-last-status.txt"
NUM_TASKS=10
DURATION=30
usage() {
echo "Usage: $0 [-n NUM_TASKS] [-d DURATION_SECS] [--results]"
exit 1
}
fetch_results() {
if [ ! -f "$RUN_FILE" ]; then
echo "No run file found. Run a test first."
exit 1
fi
echo "Fetching results from last run..."
echo ""
TMPDIR_RESULTS=$(mktemp -d)
RESULTS_PARALLEL=20
total=$(wc -l < "$RUN_FILE" | tr -d ' ')
fetched=0
# Fetch logs in parallel
while IFS= read -r task_id; do
stream="tgcalls/tgcalls/${task_id}"
(aws logs get-log-events \
--log-group-name "$LOG_GROUP" \
--log-stream-name "$stream" \
--region "$REGION" \
--query 'events[*].message' \
--output text > "${TMPDIR_RESULTS}/${task_id}" 2>/dev/null || true) &
fetched=$((fetched + 1))
# Throttle: wait every RESULTS_PARALLEL calls
if [ $((fetched % RESULTS_PARALLEL)) -eq 0 ]; then
wait
echo -ne " Fetched $fetched/$total\r"
fi
done < "$RUN_FILE"
wait
echo " Fetched $total/$total"
echo ""
# Tally results
success=0
fail=0
errors=""
no_logs_tasks=()
for result_file in "${TMPDIR_RESULTS}"/*; do
[ -f "$result_file" ] || continue
task_id=$(basename "$result_file")
output=$(cat "$result_file")
if [ -z "$output" ]; then
# No logs yet — queue for retry
no_logs_tasks+=("$task_id")
elif echo "$output" | tr '\t' '\n' | grep -q "Audio received:.*yes" && echo "$output" | tr '\t' '\n' | grep -q "Call established:.*yes"; then
success=$((success + 1))
else
fail=$((fail + 1))
reflector=$(echo "$output" | tr '\t' '\n' | grep -o 'reflector ([^)]*' | sed 's/reflector (//' || echo "unknown")
errors="${errors}\n ${task_id}: reflector=${reflector}"
fi
done
rm -rf "$TMPDIR_RESULTS"
# Retry tasks that had no logs
if [ ${#no_logs_tasks[@]} -gt 0 ]; then
echo "Retrying ${#no_logs_tasks[@]} tasks with missing logs..."
sleep 5
for task_id in "${no_logs_tasks[@]}"; do
stream="tgcalls/tgcalls/${task_id}"
output=$(aws logs get-log-events \
--log-group-name "$LOG_GROUP" \
--log-stream-name "$stream" \
--region "$REGION" \
--query 'events[*].message' \
--output text 2>/dev/null || true)
if [ -n "$output" ] && echo "$output" | tr '\t' '\n' | grep -q "Audio received:.*yes" && echo "$output" | tr '\t' '\n' | grep -q "Call established:.*yes"; then
success=$((success + 1))
else
fail=$((fail + 1))
ecs_info=""
if [ -f "$STATUS_FILE" ]; then
ecs_info=$(grep "^${task_id}" "$STATUS_FILE" | head -1 | cut -f2-3)
fi
if [ -n "$ecs_info" ]; then
errors="${errors}\n ${task_id}: exit=${ecs_info}"
else
errors="${errors}\n ${task_id} (no logs, no ECS status)"
fi
fi
done
echo ""
fi
echo "=== Test Results ==="
echo "Total tasks: $total"
echo "Success: $success"
echo "Failed: $fail"
if [ -n "$errors" ]; then
echo -e "\nFailed tasks:${errors}"
fi
exit 0
}
# Parse args
while [[ $# -gt 0 ]]; do
case $1 in
-n) NUM_TASKS="$2"; shift 2 ;;
-d) DURATION="$2"; shift 2 ;;
--results) fetch_results ;;
*) usage ;;
esac
done
# Fetch reflector list — IPs only (port randomized by CLI)
echo "Fetching reflector list..."
REFLECTOR_CSV=$(curl -s "$REFLECTOR_URL" | cut -d: -f1 | sort -u | tr '\n' ',' | sed 's/,$//')
NUM_REFLECTORS=$(echo "$REFLECTOR_CSV" | tr ',' '\n' | wc -l | tr -d ' ')
echo "Got $NUM_REFLECTORS unique reflector IPs"
# To inject bad addresses for testing, uncomment:
# NUM_BAD=$(( NUM_REFLECTORS / 9 ))
# BAD_CSV=$(for i in $(seq 1 $NUM_BAD); do echo -n "10.255.255.$((i % 256)):1,"; done | sed 's/,$//')
# REFLECTOR_CSV="${REFLECTOR_CSV},${BAD_CSV}"
# echo "Injected $NUM_BAD bad addresses (~10% of pool)"
echo "Launching $NUM_TASKS tasks (${DURATION}s each), each picks a random reflector..."
echo ""
# Clear run files
> "$RUN_FILE"
> "$STATUS_FILE"
# Launch in waves of WAVE_SIZE, waiting for each wave to complete before the next.
# Within each wave, fire PARALLEL API calls concurrently (each launching up to 10 tasks).
WAVE_SIZE=500
PARALLEL=10
TMPDIR_LAUNCH=$(mktemp -d)
remaining=$NUM_TASKS
wave=0
while [ $remaining -gt 0 ]; do
wave=$((wave + 1))
wave_target=$((remaining > WAVE_SIZE ? WAVE_SIZE : remaining))
wave_arns=()
wave_launched=0
echo "=== Wave $wave: launching $wave_target tasks ==="
while [ $wave_launched -lt $wave_target ]; do
pids=()
api_calls=0
for p in $(seq 1 $PARALLEL); do
left=$((wave_target - wave_launched - api_calls * 10))
[ $left -le 0 ] && break
batch=$((left > 10 ? 10 : left))
outfile="${TMPDIR_LAUNCH}/batch_${wave}_${wave_launched}_${p}"
api_calls=$((api_calls + 1))
(aws ecs run-task --region "$REGION" \
--cluster "$CLUSTER" \
--task-definition "$TASK_DEF" \
--launch-type FARGATE \
--count "$batch" \
--network-configuration "awsvpcConfiguration={subnets=[${SUBNETS}],securityGroups=[${SECURITY_GROUP}],assignPublicIp=ENABLED}" \
--overrides "{\"containerOverrides\":[{\"name\":\"tgcalls\",\"command\":[\"--quiet\",\"--reflector-list\",\"${REFLECTOR_CSV}\",\"--duration\",\"${DURATION}\",\"--drop-rate\",\"0.3\",\"--delay\",\"50-200\"]}]}" \
--query 'tasks[*].taskArn' --output text > "$outfile" 2>&1) &
pids+=($!)
done
for pid in "${pids[@]}"; do
wait "$pid" 2>/dev/null || true
done
for outfile in "${TMPDIR_LAUNCH}"/batch_*; do
[ -f "$outfile" ] || continue
while read -r arn; do
if [[ "$arn" == arn:* ]]; then
task_id="${arn##*/}"
wave_arns+=("$arn")
echo "$task_id" >> "$RUN_FILE"
wave_launched=$((wave_launched + 1))
fi
done < <(tr '\t' '\n' < "$outfile")
rm -f "$outfile"
done
echo " Launched $wave_launched/$wave_target in wave $wave"
done
remaining=$((remaining - wave_launched))
echo " Waiting for wave $wave ($wave_launched tasks) to finish..."
# Wait in batches of 100
for ((start=0; start<${#wave_arns[@]}; start+=100)); do
batch=("${wave_arns[@]:$start:100}")
aws ecs wait tasks-stopped \
--cluster "$CLUSTER" \
--tasks "${batch[@]}" \
--region "$REGION" 2>/dev/null || true
done
# Collect ECS task status while data is fresh (expires after ~1hr)
echo " Collecting task status for wave $wave..."
for ((start=0; start<${#wave_arns[@]}; start+=100)); do
batch=("${wave_arns[@]:$start:100}")
aws ecs describe-tasks --cluster "$CLUSTER" --tasks "${batch[@]}" --region "$REGION" \
--query 'tasks[*].[containers[0].taskArn,containers[0].exitCode,stoppedReason]' \
--output text 2>/dev/null | while IFS=$'\t' read -r arn exit_code reason; do
task_id="${arn##*/}"
echo -e "${task_id}\t${exit_code}\t${reason}" >> "$STATUS_FILE"
done
done
echo " Wave $wave complete."
echo ""
done
rm -rf "$TMPDIR_LAUNCH"
total_launched=$(wc -l < "$RUN_FILE" | tr -d ' ')
echo "Launched $total_launched/$NUM_TASKS total tasks."
echo "Run '$0 --results' to see results."
+93
View File
@@ -0,0 +1,93 @@
#!/usr/bin/env bash
# Run parallel tests, stop on first crash (non-zero exit).
set -euo pipefail
BINARY="./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli"
PARALLEL=250
DURATION=15
VERSION="11.0.0"
DROP_RATE=0.3
DELAY="50-200"
MODE="p2p"
while [[ $# -gt 0 ]]; do
case $1 in
-j) PARALLEL="$2"; shift 2 ;;
-d) DURATION="$2"; shift 2 ;;
--version) VERSION="$2"; shift 2 ;;
--drop-rate) DROP_RATE="$2"; shift 2 ;;
--delay) DELAY="$2"; shift 2 ;;
--mode) MODE="$2"; shift 2 ;;
*) echo "Usage: $0 [-j PARALLEL] [-d DURATION] [--version VER] [--drop-rate R] [--delay D] [--mode M]"; exit 1 ;;
esac
done
if [ ! -x "$BINARY" ]; then
echo "Binary not found: $BINARY"
exit 1
fi
TMPDIR=$(mktemp -d)
trap "rm -rf $TMPDIR" EXIT
echo "Running waves of $PARALLEL until first crash (${DURATION}s, drop=${DROP_RATE}, delay=${DELAY}, version=${VERSION})"
wave=0
total=0
while true; do
wave=$((wave + 1))
pids=()
for i in $(seq 1 $PARALLEL); do
id=$((total + i))
(
set +e
"$BINARY" --mode "$MODE" --duration "$DURATION" \
--drop-rate "$DROP_RATE" --delay "$DELAY" --version "$VERSION" --quiet \
> "$TMPDIR/${id}.out" 2>"$TMPDIR/${id}.err"
echo $? > "$TMPDIR/${id}.rc"
) &
pids+=($!)
done
for pid in "${pids[@]}"; do
wait "$pid" 2>/dev/null || true
done
total=$((total + PARALLEL))
# Check for crashes
crashes=0
for i in $(seq $((total - PARALLEL + 1)) $total); do
rc_file="$TMPDIR/${i}.rc"
if [ ! -f "$rc_file" ]; then
crashes=$((crashes + 1))
echo ""
echo "=== CRASH in run $i (no rc file) ==="
echo "--- stderr ---"
cat "$TMPDIR/${i}.err" 2>/dev/null || echo "(empty)"
else
rc=$(cat "$rc_file")
if [ "$rc" -gt 128 ] 2>/dev/null; then
crashes=$((crashes + 1))
echo ""
echo "=== CRASH in run $i (exit $rc) ==="
echo "--- stderr ---"
cat "$TMPDIR/${i}.err" 2>/dev/null || echo "(empty)"
echo "--- stdout ---"
cat "$TMPDIR/${i}.out" 2>/dev/null || echo "(empty)"
# Only show first crash in detail
if [ $crashes -eq 1 ]; then
echo "=== END CRASH ==="
fi
fi
fi
done
if [ $crashes -gt 0 ]; then
echo ""
echo "Wave $wave: $crashes crashes in $PARALLEL runs (total $total runs)"
exit 1
fi
echo " Wave $wave: $PARALLEL/$PARALLEL passed (total $total)"
done
+18
View File
@@ -0,0 +1,18 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
go_binary(
name = "go_sfu",
srcs = glob(["*.go"]),
cgo = True,
linkmode = "c-archive",
visibility = ["//visibility:public"],
deps = [
"@com_github_pion_datachannel//:datachannel",
"@com_github_pion_dtls_v3//:dtls",
"@com_github_pion_ice_v4//:ice",
"@com_github_pion_logging//:logging",
"@com_github_pion_rtcp//:rtcp",
"@com_github_pion_sctp//:sctp",
"@com_github_pion_srtp_v3//:srtp",
],
)
+107
View File
@@ -0,0 +1,107 @@
# Go/Pion SFU
The group call test mode uses a Go-based SFU (Selective Forwarding Unit) built with [Pion WebRTC](https://github.com/pion/webrtc), linked into the C++ `tgcalls_cli` binary via CGo.
## Build Integration
- `MODULE.bazel``rules_go` 0.60.0 + Go SDK 1.24.2 + `gazelle` 0.43.0; Pion dependencies managed via `go_deps` Gazelle extension
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/BUILD``go_binary` with `linkmode = "c-archive"` produces a static archive + CGo header, exposes `CcInfo` to C++ targets
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go.mod` / `go.sum` — Pion dependency declarations (pion/ice, pion/dtls, pion/srtp, pion/sctp)
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/BUILD` — depends on `//submodules/TgVoipWebrtc/tgcalls/tools/go_sfu` to link the Go archive
- The CGo-generated header is included as `#include "submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go_sfu.h"` in C++ code
## How It Works
The `go_binary` with `linkmode = "c-archive"` compiles Go code (including the Go runtime) into a `.a` static archive. Functions annotated with `//export` in Go become C-callable symbols. Bazel's `rules_go` automatically provides `CcInfo`, so `cc_binary` targets can depend on the Go archive via `deps` — no manual linkopts needed.
The Go runtime (GC, goroutine scheduler) runs inside the C++ process. This adds ~10MB memory overhead. `GoSfu_Init()` must be called before any other Go functions.
## Key Files
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/sfu.go` — SFU core: participant registry, join/leave/response flow, audio+video RTP forwarding, SSRC registry (audio/video/video-rtx with layer index), Colibri `ReceiverVideoConstraints`/`SenderVideoConstraints` handling, PLI/FIR forwarding, `ActiveAudioSsrcs`/`ActiveVideoSsrcs` broadcasting, `//export` C bindings (`GoSfu_Init`, `GoSfu_Create`, `GoSfu_Destroy`, `GoSfu_Join`, `GoSfu_Leave`, `GoSfu_QuerySsrc`, `GoSfu_QueryVideoSsrcs`, `GoSfu_Free`, `GoSfu_Shutdown`)
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go` — per-participant transport stack (ICE agent, DTLS conn, SRTP session, SRTCP contexts for manual RTCP decrypt/encrypt, SCTP association, data channel send/receive, per-receiver video layer selection)
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/mux.go` — packet demuxer: three-way split of ICE traffic into DTLS handshake, SRTP (RTP), and SRTCP (RTCP) channels per RFC 7983 + RFC 5761
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/go.mod` / `go.sum` — Go module with Pion dependencies
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.cpp` — C++ side that drives the group join flow and calls into Go SFU
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_mode.h` — header for group mode entry point
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_participant.h/.cpp` — shared participant lifecycle helpers (`createParticipant`, `stopParticipant`, `validateGroupState`, `printGroupSummary`), `ParticipantState` struct, audio helpers
- `submodules/TgVoipWebrtc/tgcalls/tools/cli/group_churn_mode.h/.cpp` — group-churn stress test: base group + rapid join/leave cycling
## SFU Bandwidth Adaptation
The SFU implements REMB-based bandwidth-adaptive simulcast layer selection for video. Per receiver, it maintains an EWMA-smoothed bandwidth estimate from REMB RTCP feedback and uses a `LayerSelector` state machine per (receiver, sender) pair to decide which simulcast layer to forward.
### State Machine
- **STABLE**: forwarding current layer. Checks for upswitch opportunity (REMB > threshold × 1.2) or downswitch need (REMB < threshold × 0.7).
- **PROBING_UP**: ramping RTX padding from 0 to the gap between current and target layer bitrate over 2 seconds. Aborts if REMB drops; succeeds if REMB sustains.
- **GRACE_DOWN**: REMB below downswitch threshold. Waits 500ms, then downswitches if not recovered. 5-second cooldown after any switch.
### Layer Thresholds
| Layer | Nominal | Upswitch When | Downswitch When |
|-------|---------|--------------|-----------------|
| 0 | 60 kbps | (start) | (never) |
| 1 | 110 kbps | BW > 132 kbps | BW < 77 kbps |
| 2 | 900 kbps | BW > 1,080 kbps | BW < 630 kbps |
### Layer Selection and SSRC Rewriting
The SFU forwards exactly one simulcast layer per (receiver, sender) pair. Before `ReceiverVideoConstraints` arrives, the SFU uses `requestedLayer` as the cap and forwards at `maxActiveLayer` (the highest layer the encoder actually produces). After constraints arrive, `ensureLayerSelector` sets `selectedLayer` clamped to `maxActiveLayer`.
When forwarding a non-base layer, the SFU rewrites the RTP SSRC to the primary (layer 0) SSRC. This is necessary because `IncomingVideoChannel` in CustomImpl attaches its `VideoSinkImpl` to `_mainVideoSsrc` (the first SSRC in the SIM group, i.e., layer 0). Without SSRC rewriting, packets from higher layers are delivered to the wrong receive stream and never decoded. RTX SSRCs are similarly rewritten to the layer 0 FID SSRC.
### Testing on Localhost
Use `--network-scenario step-down-up` to exercise the full adaptation path via per-client network simulation (replaces the old REMB-override `--bw-scenario`).
```bash
# Network scenario test (30s, 4 phases: uncapped → 80k egress → 200k → uncapped)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 30 --network-scenario step-down-up
```
Unit tests drive the `LayerSelector` state machine directly via mocked callbacks. Run from `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/`:
```bash
go test -run TestLayerSelector -v -timeout 60s
```
Covers upswitch L0→L1→L2, downswitch L2→L1→L0, grace-down recovery on transient dips, stale-BW idle behavior, the `OnMaxActiveLayerIncreased` fallback used when clients don't send REMB, and `maxLayer` enforcement.
### REMB-free Fallback
Real tgcalls clients negotiate `goog-remb` but use transport-cc as the primary BWE signal, so no REMB actually arrives at the SFU. This means the REMB-driven state machine never enters `PROBING_UP` in live runs. `LayerSelector.OnMaxActiveLayerIncreased(maxActive)` is the fallback: when the sender starts producing a higher simulcast layer than previously seen AND the BW estimate is stale, the SFU immediately upshifts to the highest available layer (clamped by `maxLayer`). Called from `sfu.go`'s packet-forwarding path whenever `maxActiveLayer[senderID]` is bumped.
### Key Files
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/bandwidth.go``BandwidthEstimator`, `LayerSelector`, `RtxRingBuffer`, `OnMaxActiveLayerIncreased`
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/bandwidth_test.go` — unit tests for `LayerSelector` up/down transitions
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go` — REMB parsing in `readRTCPLoop()`, `selectedLayers`
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/sfu.go` — layer-filtered forwarding, `ensureLayerSelector`, SSRC rewriting, `maxActiveLayer` tracking
## SFU Transport-CC Feedback
The SFU generates RTCP transport-cc feedback (type 205, FMT 15) per sender every 100ms. This provides the sender's GCC (Google Congestion Control) with packet arrival data, enabling BWE ramp-up so the encoder produces higher simulcast layers.
The feedback reflects actual (or simulated) packet arrivals — if ingress network simulation drops packets, the feedback reports them as missing, causing the sender's GCC to reduce bitrate.
### How It Works
1. Each incoming RTP packet is parsed for the transport-wide sequence number (header extension ID 3, one-byte RFC 5285 format)
2. `TransportCCGenerator.RecordArrival(twccSeq)` records the arrival time
3. Every 100ms, `emitFeedback()` builds an `rtcp.TransportLayerCC` packet with `PacketChunks` (run-length or status-vector encoding) and `RecvDeltas` (250µs units)
4. The feedback is marshalled, encrypted via SRTCP, and sent to the sender
5. The sender's `Call::Receiver::DeliverRtcpPacket()` feeds it to the GCC via `GroupNetworkManager::OnRtcpPacketReceived_n``_call->Receiver()->DeliverRtcpPacket()`
### Current Status
Transport-cc feedback is working: the SFU records ~60-70 arrivals per first 100ms tick, generates feedback packets (32-128 bytes), and the sender receives them. The GCC ramps from the 400kbps start bitrate to produce layer 1 (640x360). Full ramp to layer 2 (1280x720, needs ~1Mbps) requires further investigation — the GCC may need probing support or the `adjustBitratePreferences` max_bitrate_bps of 1052kbps may be a bottleneck.
### Key Files
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/twcc.go``TransportCCGenerator`, `parseTWCCSeq` (RTP header extension parser)
## SFU Network Simulation
Per-client network simulation with independent ingress (from client) and egress (to client) simulators. Each direction has: delay, jitter, packet loss, and bandwidth cap (token bucket).
```bash
# Configure via CGo: GoSfu_SetNetworkParams(handle, participantID, direction, delayMs, jitterMs, dropRate, bandwidthBps)
# direction: 0 = ingress, 1 = egress
# Network scenario test (4 phases: uncapped -> 80k -> 200k -> uncapped)
./bazel-bin/submodules/TgVoipWebrtc/tgcalls/tools/cli/tgcalls_cli --mode group --participants 2 --video --duration 30 --network-scenario step-down-up
```
### Key Files
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/network_sim.go``NetworkSimulator` (token bucket, delay, jitter, drop)
- `submodules/TgVoipWebrtc/tgcalls/tools/go_sfu/participant.go``ingressSim`, `egressSim` on each `Participant`
+475
View File
@@ -0,0 +1,475 @@
package main
import (
"sync"
"time"
)
// --- Bandwidth Estimation ---
const (
ewmaAlpha = 0.3
safetyFactor = 0.85
stalenessTTL = 5 * time.Second
)
// BandwidthEstimator maintains an EWMA-smoothed REMB estimate for a receiver.
type BandwidthEstimator struct {
mu sync.Mutex
lastREMBBps float64
smoothedBps float64
lastREMBAt time.Time
}
// OnREMB feeds a new REMB value (in bits per second) into the estimator.
func (e *BandwidthEstimator) OnREMB(bps float64) {
e.mu.Lock()
defer e.mu.Unlock()
e.lastREMBBps = bps
e.lastREMBAt = time.Now()
if e.smoothedBps == 0 {
e.smoothedBps = bps
} else {
e.smoothedBps = ewmaAlpha*bps + (1-ewmaAlpha)*e.smoothedBps
}
}
// EffectiveBps returns the safe bandwidth estimate in bps.
// Returns -1 if the estimate is stale (no REMB for stalenessTTL).
func (e *BandwidthEstimator) EffectiveBps() float64 {
e.mu.Lock()
defer e.mu.Unlock()
if e.lastREMBAt.IsZero() || time.Since(e.lastREMBAt) > stalenessTTL {
return -1
}
return e.smoothedBps * safetyFactor
}
// SmoothedBps returns the raw EWMA value (for logging).
func (e *BandwidthEstimator) SmoothedBps() float64 {
e.mu.Lock()
defer e.mu.Unlock()
return e.smoothedBps
}
// LastREMBBps returns the last raw REMB value (for logging).
func (e *BandwidthEstimator) LastREMBBps() float64 {
e.mu.Lock()
defer e.mu.Unlock()
return e.lastREMBBps
}
// --- Layer Bitrate Model ---
// LayerBitrate holds the thresholds for one simulcast layer.
type LayerBitrate struct {
Nominal float64 // typical sustained bitrate (bps)
UpThresh float64 // effective BW must exceed this to upswitch TO this layer
DownThresh float64 // effective BW must drop below this to downswitch FROM this layer
}
// layerBitrates defines the 3 simulcast layers matching tgcalls adjustVideoSendParams().
// Layer 0 has no downThresh (always viable) and no upThresh (start here).
var layerBitrates = [3]LayerBitrate{
{Nominal: 60_000, UpThresh: 0, DownThresh: 0}, // layer 0: 160x90
{Nominal: 110_000, UpThresh: 132_000, DownThresh: 77_000}, // layer 1: 320x180
{Nominal: 900_000, UpThresh: 1_080_000, DownThresh: 630_000}, // layer 2: 640x360
}
// --- RTX Ring Buffer ---
// RtxEntry stores one video RTP packet for potential retransmission as RTX padding.
type RtxEntry struct {
Payload []byte
SeqNum uint16
Timestamp uint32
}
// RtxRingBuffer is a per-sender circular buffer of recent video RTP packets.
type RtxRingBuffer struct {
mu sync.Mutex
entries []RtxEntry
head int
count int
cap int
}
// NewRtxRingBuffer creates a ring buffer with the given capacity.
func NewRtxRingBuffer(capacity int) *RtxRingBuffer {
return &RtxRingBuffer{
entries: make([]RtxEntry, capacity),
cap: capacity,
}
}
// Push adds a video RTP packet to the ring buffer.
// payload is copied so the caller can reuse their buffer.
func (r *RtxRingBuffer) Push(payload []byte, seqNum uint16, timestamp uint32) {
r.mu.Lock()
defer r.mu.Unlock()
entry := &r.entries[r.head]
if cap(entry.Payload) >= len(payload) {
entry.Payload = entry.Payload[:len(payload)]
} else {
entry.Payload = make([]byte, len(payload))
}
copy(entry.Payload, payload)
entry.SeqNum = seqNum
entry.Timestamp = timestamp
r.head = (r.head + 1) % r.cap
if r.count < r.cap {
r.count++
}
}
// Get returns up to n most recent packets (oldest first).
func (r *RtxRingBuffer) Get(n int) []RtxEntry {
r.mu.Lock()
defer r.mu.Unlock()
if n > r.count {
n = r.count
}
if n == 0 {
return nil
}
result := make([]RtxEntry, n)
start := (r.head - r.count + r.cap) % r.cap // oldest entry
readFrom := (start + r.count - n + r.cap) % r.cap // start of the n most recent
for i := 0; i < n; i++ {
idx := (readFrom + i) % r.cap
src := &r.entries[idx]
entry := RtxEntry{
Payload: make([]byte, len(src.Payload)),
SeqNum: src.SeqNum,
Timestamp: src.Timestamp,
}
copy(entry.Payload, src.Payload)
result[i] = entry
}
return result
}
// rtxEncapsulate wraps an original RTP payload into an RTX packet payload per RFC 4588.
// The RTX payload is: [2-byte original sequence number] + [original RTP payload (after header)].
// The caller is responsible for setting the RTX SSRC and incrementing RTX sequence number
// on the outer RTP header.
func rtxEncapsulate(originalPayload []byte, originalSeqNum uint16) []byte {
out := make([]byte, 2+len(originalPayload))
out[0] = byte(originalSeqNum >> 8)
out[1] = byte(originalSeqNum)
copy(out[2:], originalPayload)
return out
}
// --- Layer Selector State Machine ---
type selectorState int
const (
stateStable selectorState = iota
stateProbingUp
stateGraceDown
)
func (s selectorState) String() string {
switch s {
case stateStable:
return "STABLE"
case stateProbingUp:
return "PROBING_UP"
case stateGraceDown:
return "GRACE_DOWN"
default:
return "UNKNOWN"
}
}
const (
probeDuration = 2 * time.Second
graceDownTimeout = 500 * time.Millisecond
cooldownDuration = 5 * time.Second
tickInterval = 100 * time.Millisecond
)
// LayerSelectorCallbacks provides the hooks the state machine needs into the SFU.
type LayerSelectorCallbacks struct {
// GetEffectiveBW returns the receiver's current effective bandwidth (bps), or -1 if stale.
GetEffectiveBW func() float64
// SetSelectedLayer updates the forwarding layer for this (receiver, sender) pair.
SetSelectedLayer func(layer int)
// SendPLI sends a PLI to the sender for the given SSRC.
SendPLI func(ssrc uint32)
// GetSenderVideoLayers returns the sender's simulcast layers.
GetSenderVideoLayers func() []SimulcastLayer
// GetRtxBuffer returns the sender's RTX ring buffer.
GetRtxBuffer func() *RtxRingBuffer
// SendRtxPadding sends an RTX padding packet to the receiver.
// rtxSSRC is the FID SSRC, seqNum is the RTX sequence number.
SendRtxPadding func(rtxPayload []byte, rtxSSRC uint32, seqNum uint16, timestamp uint32)
// Log emits a log message.
Log func(level string, format string, args ...interface{})
}
// LayerSelector manages the state machine for one (receiver, sender) pair.
type LayerSelector struct {
mu sync.Mutex
receiverID int
senderID int
currentLayer int
maxLayer int // max layer the receiver requested
state selectorState
callbacks LayerSelectorCallbacks
// Probing state
probeTarget int // layer we're probing toward
probeStartTime time.Time
probeRtxSeq uint16 // incrementing RTX sequence number for padding
// Grace-down state
graceStartTime time.Time
// Cooldown
lastSwitchTime time.Time
// Control
stopCh chan struct{}
done chan struct{}
}
// NewLayerSelector creates and starts a new LayerSelector.
// initialLayer is the layer to start forwarding (typically = requestedLayer).
func NewLayerSelector(receiverID, senderID, initialLayer, maxLayer int, cb LayerSelectorCallbacks) *LayerSelector {
ls := &LayerSelector{
receiverID: receiverID,
senderID: senderID,
currentLayer: initialLayer,
maxLayer: maxLayer,
state: stateStable,
callbacks: cb,
stopCh: make(chan struct{}),
done: make(chan struct{}),
}
go ls.run()
return ls
}
// Stop terminates the selector's tick loop.
func (ls *LayerSelector) Stop() {
close(ls.stopCh)
<-ls.done
}
// SetMaxLayer updates the maximum layer the receiver wants (from ReceiverVideoConstraints).
func (ls *LayerSelector) SetMaxLayer(maxLayer int) {
ls.mu.Lock()
defer ls.mu.Unlock()
ls.maxLayer = maxLayer
// If current layer exceeds new max, downswitch immediately.
if ls.currentLayer > maxLayer {
ls.switchLayer(maxLayer)
}
}
// OnMaxActiveLayerIncreased is called when the sender starts producing a
// higher simulcast layer than previously observed. If the BW estimate is
// stale (no REMB arriving — common when clients use transport-cc exclusively
// and the SFU hasn't generated REMB), upshift immediately up to maxLayer so
// the receiver gets the best available layer. When REMB is fresh, the state
// machine is in charge and this is a no-op.
func (ls *LayerSelector) OnMaxActiveLayerIncreased(maxActive int) {
ls.mu.Lock()
defer ls.mu.Unlock()
if ls.callbacks.GetEffectiveBW() >= 0 {
// BW estimate available — state machine decides.
return
}
target := maxActive
if target > ls.maxLayer {
target = ls.maxLayer
}
if target > ls.currentLayer {
ls.switchLayer(target)
}
}
// CurrentLayer returns the currently selected layer.
func (ls *LayerSelector) CurrentLayer() int {
ls.mu.Lock()
defer ls.mu.Unlock()
return ls.currentLayer
}
func (ls *LayerSelector) run() {
defer close(ls.done)
ticker := time.NewTicker(tickInterval)
defer ticker.Stop()
for {
select {
case <-ls.stopCh:
return
case <-ticker.C:
ls.tick()
}
}
}
func (ls *LayerSelector) tick() {
ls.mu.Lock()
defer ls.mu.Unlock()
effectiveBW := ls.callbacks.GetEffectiveBW()
if effectiveBW < 0 {
// Stale estimate — do nothing.
return
}
switch ls.state {
case stateStable:
ls.tickStable(effectiveBW)
case stateProbingUp:
ls.tickProbingUp(effectiveBW)
case stateGraceDown:
ls.tickGraceDown(effectiveBW)
}
}
func (ls *LayerSelector) tickStable(effectiveBW float64) {
// Check for upswitch opportunity.
nextLayer := ls.currentLayer + 1
if nextLayer <= ls.maxLayer && nextLayer <= 2 {
if !ls.inCooldown() && effectiveBW > layerBitrates[nextLayer].UpThresh {
ls.state = stateProbingUp
ls.probeTarget = nextLayer
ls.probeStartTime = time.Now()
ls.callbacks.Log("INFO", "Participant %d<-%d: STABLE->PROBING_UP (BW=%.0fkbps, target=layer%d@%.0fkbps)",
ls.receiverID, ls.senderID, effectiveBW/1000, nextLayer, layerBitrates[nextLayer].UpThresh/1000)
return
}
}
// Check for downswitch need.
if ls.currentLayer > 0 {
if effectiveBW < layerBitrates[ls.currentLayer].DownThresh {
ls.state = stateGraceDown
ls.graceStartTime = time.Now()
ls.callbacks.Log("INFO", "Participant %d<-%d: STABLE->GRACE_DOWN (BW=%.0fkbps, thresh=%.0fkbps)",
ls.receiverID, ls.senderID, effectiveBW/1000, layerBitrates[ls.currentLayer].DownThresh/1000)
return
}
}
}
func (ls *LayerSelector) tickProbingUp(effectiveBW float64) {
elapsed := time.Since(ls.probeStartTime)
// Abort if bandwidth dropped below current layer's nominal bitrate.
if effectiveBW < layerBitrates[ls.currentLayer].Nominal {
ls.state = stateStable
ls.lastSwitchTime = time.Now() // enter cooldown
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (abort, BW=%.0fkbps < nominal=%.0fkbps)",
ls.receiverID, ls.senderID, effectiveBW/1000, layerBitrates[ls.currentLayer].Nominal/1000)
return
}
// Probe complete — switch up.
if elapsed >= probeDuration {
if effectiveBW > layerBitrates[ls.probeTarget].Nominal {
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (success, switching to layer %d)",
ls.receiverID, ls.senderID, ls.probeTarget)
ls.switchLayer(ls.probeTarget)
return
}
// BW not sufficient at end of probe — abort.
ls.state = stateStable
ls.lastSwitchTime = time.Now()
ls.callbacks.Log("INFO", "Participant %d<-%d: PROBING_UP->STABLE (probe done but BW=%.0fkbps insufficient)",
ls.receiverID, ls.senderID, effectiveBW/1000)
return
}
// Send RTX padding during probe.
ls.sendProbePadding(elapsed)
}
func (ls *LayerSelector) tickGraceDown(effectiveBW float64) {
// If bandwidth recovered, cancel grace period.
if effectiveBW >= layerBitrates[ls.currentLayer].DownThresh {
ls.state = stateStable
ls.callbacks.Log("INFO", "Participant %d<-%d: GRACE_DOWN->STABLE (recovered, BW=%.0fkbps)",
ls.receiverID, ls.senderID, effectiveBW/1000)
return
}
// Grace period expired — downswitch.
if time.Since(ls.graceStartTime) >= graceDownTimeout {
targetLayer := ls.currentLayer - 1
if targetLayer < 0 {
targetLayer = 0
}
ls.callbacks.Log("INFO", "Participant %d<-%d: GRACE_DOWN->STABLE (downswitch to layer %d)",
ls.receiverID, ls.senderID, targetLayer)
ls.switchLayer(targetLayer)
}
}
func (ls *LayerSelector) switchLayer(newLayer int) {
oldLayer := ls.currentLayer
ls.currentLayer = newLayer
ls.state = stateStable
ls.lastSwitchTime = time.Now()
ls.callbacks.SetSelectedLayer(newLayer)
// Request keyframe at the new layer.
layers := ls.callbacks.GetSenderVideoLayers()
if newLayer < len(layers) {
ls.callbacks.SendPLI(layers[newLayer].SSRC)
ls.callbacks.Log("INFO", "Participant %d<-%d: switched layer %d->%d (PLI sent for SSRC=%d)",
ls.receiverID, ls.senderID, oldLayer, newLayer, layers[newLayer].SSRC)
}
}
func (ls *LayerSelector) inCooldown() bool {
return !ls.lastSwitchTime.IsZero() && time.Since(ls.lastSwitchTime) < cooldownDuration
}
func (ls *LayerSelector) sendProbePadding(elapsed time.Duration) {
// Calculate target padding rate: ramp from 0 to gap over probeDuration.
gap := layerBitrates[ls.probeTarget].Nominal - layerBitrates[ls.currentLayer].Nominal
progress := float64(elapsed) / float64(probeDuration)
targetBps := gap * progress
// How many bytes to send in this 100ms tick.
bytesPerTick := targetBps / 8 / (float64(time.Second) / float64(tickInterval))
rtxBuf := ls.callbacks.GetRtxBuffer()
if rtxBuf == nil {
return
}
// Pull packets from the ring buffer to fill the target bytes.
entries := rtxBuf.Get(20) // enough for one tick
if len(entries) == 0 {
return
}
layers := ls.callbacks.GetSenderVideoLayers()
if ls.currentLayer >= len(layers) {
return
}
rtxSSRC := layers[ls.currentLayer].FidSSRC
if rtxSSRC == 0 {
return
}
var sentBytes float64
entryIdx := 0
for sentBytes < bytesPerTick && entryIdx < len(entries) {
entry := entries[entryIdx]
entryIdx++
rtxPayload := rtxEncapsulate(entry.Payload, entry.SeqNum)
ls.probeRtxSeq++
ls.callbacks.SendRtxPadding(rtxPayload, rtxSSRC, ls.probeRtxSeq, entry.Timestamp)
sentBytes += float64(len(rtxPayload))
}
}
+249
View File
@@ -0,0 +1,249 @@
package main
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// mockCallbacks is a test harness for driving the LayerSelector. BW is
// atomic so the selector's run() goroutine can read while the test writes.
type mockCallbacks struct {
bw atomic.Int64 // current effective bandwidth in bps; negative = stale
layers []SimulcastLayer
selectedLayer atomic.Int32
pliCount atomic.Int32
probePaddings atomic.Int32
logMu sync.Mutex
logBuf []string
}
func newMockCallbacks(layers []SimulcastLayer, initialBW float64) *mockCallbacks {
m := &mockCallbacks{layers: layers}
m.bw.Store(int64(initialBW))
m.selectedLayer.Store(-1)
return m
}
func (m *mockCallbacks) setBW(bps float64) { m.bw.Store(int64(bps)) }
func (m *mockCallbacks) currentSelected() int { return int(m.selectedLayer.Load()) }
func (m *mockCallbacks) toCallbacks() LayerSelectorCallbacks {
return LayerSelectorCallbacks{
GetEffectiveBW: func() float64 {
v := float64(m.bw.Load())
if v < 0 {
return -1
}
return v
},
SetSelectedLayer: func(layer int) {
m.selectedLayer.Store(int32(layer))
},
SendPLI: func(ssrc uint32) {
m.pliCount.Add(1)
},
GetSenderVideoLayers: func() []SimulcastLayer {
return m.layers
},
GetRtxBuffer: func() *RtxRingBuffer {
return nil // probing padding no-ops without a buffer
},
SendRtxPadding: func(rtxPayload []byte, rtxSSRC uint32, seqNum uint16, timestamp uint32) {
m.probePaddings.Add(1)
},
Log: func(level string, format string, args ...interface{}) {
m.logMu.Lock()
m.logBuf = append(m.logBuf, fmt.Sprintf("["+level+"] "+format, args...))
m.logMu.Unlock()
},
}
}
// waitForLayer polls the selector's currentLayer up to timeout for a change
// to `want`. Returns true if reached, false on timeout.
func waitForLayer(ls *LayerSelector, want int, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if ls.CurrentLayer() == want {
return true
}
time.Sleep(20 * time.Millisecond)
}
return false
}
func testLayers() []SimulcastLayer {
return []SimulcastLayer{
{SSRC: 1001, FidSSRC: 1002},
{SSRC: 1003, FidSSRC: 1004},
{SSRC: 1005, FidSSRC: 1006},
}
}
// TestLayerSelectorUpswitch verifies L0 -> L1 -> L2 based on rising BW.
//
// Thresholds (from layerBitrates):
//
// L1 UpThresh = 132 kbps → needs REMB > ~155 kbps (with 0.85 safety factor)
// L2 UpThresh = 1080 kbps → needs REMB > ~1271 kbps
//
// The selector's state machine enforces a 5s cooldown after each switch, so
// the whole test runs in ~8-10 seconds.
func TestLayerSelectorUpswitch(t *testing.T) {
m := newMockCallbacks(testLayers(), 200_000) // > L1 UpThresh
ls := NewLayerSelector(1, 0, 0, 2, m.toCallbacks())
defer ls.Stop()
// L0 -> L1: should enter PROBING_UP within 150ms (one tick), then
// complete the 2s probe and switch to L1.
if !waitForLayer(ls, 1, 3*time.Second) {
t.Fatalf("L0->L1 upswitch timed out; currentLayer=%d selected=%d", ls.CurrentLayer(), m.currentSelected())
}
if got := m.currentSelected(); got != 1 {
t.Fatalf("after L1 upswitch, SetSelectedLayer was not called with 1 (got %d)", got)
}
if pli := m.pliCount.Load(); pli < 1 {
t.Fatalf("expected at least 1 PLI on layer switch, got %d", pli)
}
// L1 -> L2: raise BW above L2 UpThresh. Wait out the 5s cooldown and
// then the 2s probe (total ~7-8s).
m.setBW(1_500_000)
if !waitForLayer(ls, 2, 10*time.Second) {
t.Fatalf("L1->L2 upswitch timed out; currentLayer=%d", ls.CurrentLayer())
}
if got := m.currentSelected(); got != 2 {
t.Fatalf("after L2 upswitch, SetSelectedLayer was not called with 2 (got %d)", got)
}
}
// TestLayerSelectorDownswitch verifies L2 -> L1 -> L0 based on falling BW.
// Starts the selector pre-positioned at L2 by setting its state directly
// via `switchLayer`-equivalent initial-layer argument, then drives BW down.
//
// Thresholds:
//
// L2 DownThresh = 630 kbps → needs REMB < ~741 kbps
// L1 DownThresh = 77 kbps → needs REMB < ~91 kbps
//
// Downswitches are governed by a 500ms grace period, no cooldown, so this
// test runs in ~1.5 seconds.
func TestLayerSelectorDownswitch(t *testing.T) {
m := newMockCallbacks(testLayers(), 1_500_000) // high BW, at L2
ls := NewLayerSelector(1, 0, 2, 2, m.toCallbacks())
defer ls.Stop()
// Drop BW below L2 downswitch threshold. Effective = 500k * 0.85 = 425k
// is NOT below 630k effective threshold directly. Use 600k raw so
// effective = 510k, well below 630k.
m.setBW(600_000)
if !waitForLayer(ls, 1, 2*time.Second) {
t.Fatalf("L2->L1 downswitch timed out; currentLayer=%d", ls.CurrentLayer())
}
if got := m.currentSelected(); got != 1 {
t.Fatalf("after L1 downswitch, SetSelectedLayer was not called with 1 (got %d)", got)
}
// Drop below L1 downswitch threshold (77k effective → raw < 91k).
// Use 50k raw → effective 42k.
m.setBW(50_000)
if !waitForLayer(ls, 0, 2*time.Second) {
t.Fatalf("L1->L0 downswitch timed out; currentLayer=%d", ls.CurrentLayer())
}
if got := m.currentSelected(); got != 0 {
t.Fatalf("after L0 downswitch, SetSelectedLayer was not called with 0 (got %d)", got)
}
}
// TestLayerSelectorGraceDownRecovery verifies that a transient BW dip that
// recovers within the 500ms grace window does NOT cause a downswitch.
func TestLayerSelectorGraceDownRecovery(t *testing.T) {
m := newMockCallbacks(testLayers(), 1_500_000)
ls := NewLayerSelector(1, 0, 2, 2, m.toCallbacks())
defer ls.Stop()
// Dip below downthresh, then recover before grace expires.
m.setBW(500_000) // below L2 downthresh
time.Sleep(300 * time.Millisecond)
m.setBW(1_500_000) // recovered
time.Sleep(500 * time.Millisecond)
if got := ls.CurrentLayer(); got != 2 {
t.Fatalf("transient dip should not have downswitched; currentLayer=%d", got)
}
}
// TestLayerSelectorStaleBW verifies that with no REMB data (BW=-1), the
// state machine does not transition.
func TestLayerSelectorStaleBW(t *testing.T) {
m := newMockCallbacks(testLayers(), -1) // stale
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
defer ls.Stop()
time.Sleep(1 * time.Second)
if got := ls.CurrentLayer(); got != 1 {
t.Fatalf("stale BW should not trigger a transition; currentLayer=%d", got)
}
}
// TestLayerSelectorOnMaxActiveLayerIncreasedWhenStale verifies the fallback
// path: when BW is stale (clients don't send REMB), discovery of a higher
// active layer from the sender causes an immediate upswitch.
func TestLayerSelectorOnMaxActiveLayerIncreasedWhenStale(t *testing.T) {
m := newMockCallbacks(testLayers(), -1)
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
defer ls.Stop()
// Nothing has happened yet.
if got := ls.CurrentLayer(); got != 1 {
t.Fatalf("unexpected initial layer %d", got)
}
// Sender starts producing L2. With stale BW, we should upshift
// immediately up to the receiver's requested maxLayer.
ls.OnMaxActiveLayerIncreased(2)
if got := ls.CurrentLayer(); got != 2 {
t.Fatalf("expected upshift to L2 on maxActive increase with stale BW; got %d", got)
}
if got := m.currentSelected(); got != 2 {
t.Fatalf("SetSelectedLayer should have been called with 2; got %d", got)
}
}
// TestLayerSelectorOnMaxActiveLayerIncreasedWhenFresh verifies that when BW
// is fresh, OnMaxActiveLayerIncreased is a no-op — the state machine is in
// charge of layer selection.
func TestLayerSelectorOnMaxActiveLayerIncreasedWhenFresh(t *testing.T) {
m := newMockCallbacks(testLayers(), 200_000) // fresh, enough for L1 only
ls := NewLayerSelector(1, 0, 1, 2, m.toCallbacks())
defer ls.Stop()
ls.OnMaxActiveLayerIncreased(2)
if got := ls.CurrentLayer(); got != 1 {
t.Fatalf("fresh BW should leave state machine in charge; current=%d", got)
}
}
// TestLayerSelectorRespectsMaxLayer verifies that upswitches never exceed
// the receiver's requested maxLayer.
func TestLayerSelectorRespectsMaxLayer(t *testing.T) {
m := newMockCallbacks(testLayers(), 2_000_000) // way more than needed for L2
ls := NewLayerSelector(1, 0, 0, 1, m.toCallbacks())
defer ls.Stop()
// Wait long enough for an L0->L1 upswitch (~2.2s). Then wait past the
// cooldown (5s) plus another probe window (2s) to ensure the selector
// does NOT attempt to probe beyond maxLayer=1.
if !waitForLayer(ls, 1, 3*time.Second) {
t.Fatalf("L0->L1 upswitch timed out")
}
time.Sleep(8 * time.Second)
if got := ls.CurrentLayer(); got != 1 {
t.Fatalf("selector upshifted beyond maxLayer=1; got %d", got)
}
}
+27
View File
@@ -0,0 +1,27 @@
module github.com/nicegram/AltTgCalls/tools/go_sfu
go 1.24.2
require (
github.com/pion/datachannel v1.5.10
github.com/pion/dtls/v3 v3.0.6
github.com/pion/ice/v4 v4.0.7
github.com/pion/logging v0.2.3
github.com/pion/rtcp v1.2.15
github.com/pion/sctp v1.8.37
github.com/pion/srtp/v3 v3.0.5
)
require (
github.com/google/uuid v1.6.0 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtp v1.8.17 // indirect
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.29.0 // indirect
)
+44
View File
@@ -0,0 +1,44 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o=
github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M=
github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
github.com/pion/ice/v4 v4.0.7 h1:mnwuT3n3RE/9va41/9QJqN5+Bhc0H/x/ZyiVlWMw35M=
github.com/pion/ice/v4 v4.0.7/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
github.com/pion/rtp v1.8.17 h1:CFhaPN8Ikt9Sk7B3pic0kfwVia2dUMEtPSL34Gvihjw=
github.com/pion/rtp v1.8.17/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs=
github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
github.com/pion/srtp/v3 v3.0.5 h1:8XLB6Dt3QXkMkRFpoqC3314BemkpMQK2mZeJc4pUKqo=
github.com/pion/srtp/v3 v3.0.5/go.mod h1:r1G7y5r1scZRLe2QJI/is+/O83W2d+JoEsuIexpw+uM=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+239
View File
@@ -0,0 +1,239 @@
package main
import (
"fmt"
"io"
"net"
"sync"
"time"
)
const (
muxReadBufSize = 8192
muxChanBufSize = 256
)
// isDTLS returns true if the first byte indicates a DTLS record (RFC 7983: 2063).
func isDTLS(b byte) bool {
return b >= 20 && b <= 63
}
// isRTPOrRTCP returns true if the first byte indicates an RTP/RTCP packet (RFC 7983: 128191).
func isRTPOrRTCP(b byte) bool {
return b >= 128 && b <= 191
}
// isRTCP returns true if the packet is RTCP (not RTP) per RFC 5761 Section 4.
// RTCP packet types (byte[1]) are 200-211. RTP with Marker=1 and dynamic PT >= 96
// gives byte[1] >= 224, so we use byte[1] >= 200 && byte[1] < 224 to exclude RTP.
// In SRTCP the fixed header is unencrypted, so byte[1] is readable.
func isRTCP(pkt []byte) bool {
return len(pkt) >= 2 && pkt[1] >= 200 && pkt[1] < 224
}
// PacketDemux reads from a net.Conn and routes packets to separate DTLS,
// SRTP (RTP only), and RTCP channels based on RFC 7983 first-byte classification
// and RTP/RTCP payload type demux.
type PacketDemux struct {
conn net.Conn
dtlsCh chan []byte
srtpCh chan []byte
rtcpCh chan []byte
once sync.Once
closed chan struct{}
label string
}
func (d *PacketDemux) logf(format string, args ...interface{}) {
fmt.Printf("[demux-%s] %s\n", d.label, fmt.Sprintf(format, args...))
}
// NewPacketDemux creates a PacketDemux and starts the read loop goroutine.
func NewPacketDemux(conn net.Conn, label string) *PacketDemux {
d := &PacketDemux{
conn: conn,
dtlsCh: make(chan []byte, muxChanBufSize),
srtpCh: make(chan []byte, muxChanBufSize),
rtcpCh: make(chan []byte, muxChanBufSize),
closed: make(chan struct{}),
label: label,
}
go d.readLoop()
return d
}
func (d *PacketDemux) readLoop() {
buf := make([]byte, muxReadBufSize)
dtlsCount := 0
srtpCount := 0
rtcpCount := 0
otherCount := 0
for {
n, err := d.conn.Read(buf)
if err != nil {
d.Close()
return
}
if n == 0 {
continue
}
pkt := make([]byte, n)
copy(pkt, buf[:n])
switch {
case isDTLS(pkt[0]):
dtlsCount++
if dtlsCount <= 5 {
d.logf("DTLS packet #%d: %d bytes (first byte: 0x%02x)", dtlsCount, n, pkt[0])
}
select {
case d.dtlsCh <- pkt:
default:
d.logf("DTLS channel full, dropping packet")
}
case isRTPOrRTCP(pkt[0]):
if isRTCP(pkt) {
rtcpCount++
if rtcpCount <= 3 {
d.logf("RTCP packet #%d: %d bytes (type byte: 0x%02x)", rtcpCount, n, pkt[1])
}
select {
case d.rtcpCh <- pkt:
default:
// drop if channel full
}
} else {
srtpCount++
if srtpCount == 1 {
d.logf("First SRTP packet: %d bytes", n)
}
select {
case d.srtpCh <- pkt:
default:
// drop if channel full
}
}
default:
otherCount++
if otherCount <= 3 {
d.logf("Other packet: %d bytes (first byte: 0x%02x)", n, pkt[0])
}
}
}
}
// Close shuts down the demuxer and the underlying connection.
func (d *PacketDemux) Close() error {
var err error
d.once.Do(func() {
close(d.closed)
err = d.conn.Close()
})
return err
}
// DTLSEndpoint returns a net.Conn that yields only DTLS packets.
func (d *PacketDemux) DTLSEndpoint() net.Conn {
return &demuxEndpoint{demux: d, ch: d.dtlsCh}
}
// SRTPEndpoint returns a net.Conn that yields only SRTP (RTP) packets.
// RTCP packets are routed to RTCPChannel() instead.
func (d *PacketDemux) SRTPEndpoint() net.Conn {
return &demuxEndpoint{demux: d, ch: d.srtpCh}
}
// RTCPChannel returns a channel that receives raw encrypted SRTCP packets.
// These must be decrypted externally (not via SessionSRTP which only handles RTP).
func (d *PacketDemux) RTCPChannel() <-chan []byte {
return d.rtcpCh
}
// demuxEndpoint implements net.Conn for a single demux channel.
type demuxEndpoint struct {
demux *PacketDemux
ch chan []byte
mu sync.Mutex
leftover []byte
}
func (e *demuxEndpoint) Read(b []byte) (int, error) {
e.mu.Lock()
if len(e.leftover) > 0 {
n := copy(b, e.leftover)
e.leftover = e.leftover[n:]
if len(e.leftover) == 0 {
e.leftover = nil
}
e.mu.Unlock()
return n, nil
}
e.mu.Unlock()
select {
case <-e.demux.closed:
return 0, io.EOF
case pkt, ok := <-e.ch:
if !ok {
return 0, io.EOF
}
n := copy(b, pkt)
if n < len(pkt) {
e.mu.Lock()
e.leftover = pkt[n:]
e.mu.Unlock()
}
return n, nil
}
}
func (e *demuxEndpoint) Write(b []byte) (int, error) {
return e.demux.conn.Write(b)
}
func (e *demuxEndpoint) Close() error {
return e.demux.Close()
}
func (e *demuxEndpoint) LocalAddr() net.Addr {
return e.demux.conn.LocalAddr()
}
func (e *demuxEndpoint) RemoteAddr() net.Addr {
return e.demux.conn.RemoteAddr()
}
func (e *demuxEndpoint) SetDeadline(t time.Time) error {
return e.demux.conn.SetDeadline(t)
}
func (e *demuxEndpoint) SetReadDeadline(t time.Time) error {
return e.demux.conn.SetReadDeadline(t)
}
func (e *demuxEndpoint) SetWriteDeadline(t time.Time) error {
return e.demux.conn.SetWriteDeadline(t)
}
// connToPacketConn wraps a net.Conn into a net.PacketConn.
// It is used to adapt a demuxEndpoint for pion/dtls.Server(), which
// requires net.PacketConn. Since the endpoint is already bound to a
// single peer, ReadFrom returns the conn's RemoteAddr and WriteTo ignores
// the addr parameter.
type connToPacketConn struct {
net.Conn
}
// WrapAsPacketConn adapts a net.Conn to net.PacketConn.
func WrapAsPacketConn(c net.Conn) net.PacketConn {
return &connToPacketConn{Conn: c}
}
func (c *connToPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := c.Conn.Read(b)
return n, c.Conn.RemoteAddr(), err
}
func (c *connToPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
return c.Conn.Write(b)
}
+128
View File
@@ -0,0 +1,128 @@
package main
import (
"math/rand"
"sync"
"time"
)
// NetworkSimulator models a uni-directional network pipe with delay, jitter,
// packet loss, and bandwidth cap (token bucket).
type NetworkSimulator struct {
mu sync.Mutex
delayMs int
jitterMs int
dropRate float64
bandwidthBps int64
// Token bucket for bandwidth cap.
tokens float64 // available tokens (bits)
maxTokens float64 // max tokens = 200ms worth of bandwidth
lastRefill time.Time
rng *rand.Rand
closed bool
}
// NewNetworkSimulator creates a simulator with no simulation (passthrough).
func NewNetworkSimulator() *NetworkSimulator {
return &NetworkSimulator{
lastRefill: time.Now(),
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// SetParams reconfigures the simulator at runtime. Thread-safe.
func (ns *NetworkSimulator) SetParams(delayMs, jitterMs int, dropRate float64, bandwidthBps int64) {
ns.mu.Lock()
defer ns.mu.Unlock()
ns.delayMs = delayMs
ns.jitterMs = jitterMs
ns.dropRate = dropRate
ns.bandwidthBps = bandwidthBps
if bandwidthBps > 0 {
ns.maxTokens = float64(bandwidthBps) * 0.2 // 200ms buffer
if ns.tokens > ns.maxTokens {
ns.tokens = ns.maxTokens
}
} else {
ns.maxTokens = 0
ns.tokens = 0
}
}
// Send processes a packet through the simulator. deliverFn is called
// (possibly asynchronously) after simulation. The packet bytes are copied
// if delivery is deferred.
func (ns *NetworkSimulator) Send(pkt []byte, deliverFn func([]byte)) {
ns.mu.Lock()
if ns.closed {
ns.mu.Unlock()
return
}
// Drop check.
if ns.dropRate > 0 && ns.rng.Float64() < ns.dropRate {
ns.mu.Unlock()
return
}
// Bandwidth cap: token bucket.
if ns.bandwidthBps > 0 {
ns.refillTokens()
cost := float64(len(pkt)) * 8
if ns.tokens < cost {
// Queue full / no tokens — tail drop.
ns.mu.Unlock()
return
}
ns.tokens -= cost
}
// Calculate delay.
delayMs := ns.delayMs
if ns.jitterMs > 0 {
delayMs += ns.rng.Intn(2*ns.jitterMs+1) - ns.jitterMs
if delayMs < 0 {
delayMs = 0
}
}
ns.mu.Unlock()
if delayMs == 0 {
deliverFn(pkt)
return
}
// Copy packet for deferred delivery.
pktCopy := make([]byte, len(pkt))
copy(pktCopy, pkt)
time.AfterFunc(time.Duration(delayMs)*time.Millisecond, func() {
deliverFn(pktCopy)
})
}
// Close stops the simulator. Pending delayed packets may still fire.
func (ns *NetworkSimulator) Close() {
ns.mu.Lock()
ns.closed = true
ns.mu.Unlock()
}
// refillTokens adds tokens based on elapsed time. Must be called with mu held.
func (ns *NetworkSimulator) refillTokens() {
now := time.Now()
elapsed := now.Sub(ns.lastRefill).Seconds()
ns.lastRefill = now
ns.tokens += float64(ns.bandwidthBps) * elapsed
if ns.tokens > ns.maxTokens {
ns.tokens = ns.maxTokens
}
}
// IsPassthrough returns true if no simulation is configured.
func (ns *NetworkSimulator) IsPassthrough() bool {
ns.mu.Lock()
defer ns.mu.Unlock()
return ns.delayMs == 0 && ns.jitterMs == 0 && ns.dropRate == 0 && ns.bandwidthBps == 0
}
+637
View File
@@ -0,0 +1,637 @@
package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"math/big"
"net"
"sync"
"time"
"github.com/pion/datachannel"
"github.com/pion/dtls/v3"
"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/sctp"
"github.com/pion/srtp/v3"
)
// ParticipantConfig holds the client's join parameters extracted from the join payload.
type ParticipantConfig struct {
AudioSSRC uint32
Ufrag string
Pwd string
Fingerprint string // SHA-256, colon-separated uppercase hex (e.g., "AB:CD:EF:...")
}
// Participant holds the per-participant transport stack: ICE → DTLS → SRTP + SCTP/DataChannel.
type Participant struct {
ID int
AudioSSRC uint32
iceAgent *ice.Agent
iceConn *ice.Conn
demux *PacketDemux
dtlsConn *dtls.Conn
srtpSession *srtp.SessionSRTP
srtpWriter *srtp.WriteStreamSRTP
srtpProfile srtp.ProtectionProfile
srtpKeys srtp.SessionKeys // saved for creating SRTCP contexts
// Separate SRTCP contexts for manual RTCP decrypt/encrypt.
// These are independent from the SessionSRTP used for RTP.
srtcpRemoteCtx *srtp.Context // decrypt SRTCP received from this participant
srtcpLocalCtx *srtp.Context // encrypt SRTCP sent to this participant
srtcpMu sync.Mutex // protects srtcpLocalCtx (single-writer)
sctpAssoc *sctp.Association
dataChannel *datachannel.DataChannel
tlsCert tls.Certificate
fingerprint string // SHA-256, colon-separated uppercase hex
localUfrag string
localPwd string
loggerFactory logging.LoggerFactory
log logging.LeveledLogger
// Video layer selection: receiver requests which layer to receive from each sender.
videoLayerMu sync.RWMutex
requestedLayers map[int]int // senderID -> layer index
// Bandwidth estimation from REMB.
bwEstimator *BandwidthEstimator
// Selected layers: what the SFU actually forwards (set by LayerSelector).
selectedLayerMu sync.RWMutex
selectedLayers map[int]int // senderID -> layer index
onColibriMessage func(participantID int, msg string) // set before Connect(), read from acceptDataChannel goroutine
// RTCP feedback callback: called when PLI or FIR is received from this participant.
// mediaSSRC is the SSRC the receiver wants a keyframe for.
onRTCPFeedback func(participantID int, mediaSSRC uint32, isFIR bool)
// Network simulation (delay/jitter/loss/bandwidth cap per direction).
ingressSim *NetworkSimulator
egressSim *NetworkSimulator
closed chan struct{}
once sync.Once
}
// NewParticipant creates a new Participant with an ICE agent and self-signed certificate.
// It does NOT start ICE gathering or connection — call GatherCandidates() and Connect() for that.
func NewParticipant(id int, config ParticipantConfig, loggerFactory logging.LoggerFactory) (*Participant, error) {
log := loggerFactory.NewLogger(fmt.Sprintf("participant-%d", id))
// Generate self-signed ECDSA P-256 certificate.
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate ECDSA key: %w", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
if err != nil {
return nil, fmt.Errorf("create certificate: %w", err)
}
tlsCert := tls.Certificate{
Certificate: [][]byte{certDER},
PrivateKey: privKey,
}
// Compute SHA-256 fingerprint of the DER certificate.
hash := sha256.Sum256(certDER)
fingerprint := formatFingerprint(hash[:])
// Create ICE agent — UDP, host candidates, ICE-lite.
// The tgcalls GroupNetworkManager hardcodes ICEROLE_CONTROLLED for the client,
// so the SFU must be the controlling side (use Dial, not Accept).
// ICE-lite: the SFU passively accepts incoming connectivity checks.
// No remote candidates needed: when the client's STUN binding requests arrive,
// pion creates peer-reflexive candidates automatically.
agent, err := ice.NewAgent(&ice.AgentConfig{
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4},
CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost},
Lite: true,
IncludeLoopback: true,
IPFilter: func(ip net.IP) bool {
return true // accept all interfaces, including loopback
},
LoggerFactory: loggerFactory,
})
if err != nil {
return nil, fmt.Errorf("create ICE agent: %w", err)
}
localUfrag, localPwd, err := agent.GetLocalUserCredentials()
if err != nil {
_ = agent.Close()
return nil, fmt.Errorf("get local credentials: %w", err)
}
log.Infof("Created participant %d (SSRC=%d, ufrag=%s, fingerprint=%s)", id, config.AudioSSRC, localUfrag, fingerprint)
return &Participant{
ID: id,
AudioSSRC: config.AudioSSRC,
iceAgent: agent,
tlsCert: tlsCert,
fingerprint: fingerprint,
localUfrag: localUfrag,
localPwd: localPwd,
loggerFactory: loggerFactory,
log: log,
requestedLayers: make(map[int]int),
bwEstimator: &BandwidthEstimator{},
selectedLayers: make(map[int]int),
ingressSim: NewNetworkSimulator(),
egressSim: NewNetworkSimulator(),
closed: make(chan struct{}),
}, nil
}
// Fingerprint returns the SHA-256 fingerprint of the participant's DTLS certificate.
func (p *Participant) Fingerprint() string {
return p.fingerprint
}
// LocalUfrag returns the local ICE username fragment.
func (p *Participant) LocalUfrag() string {
return p.localUfrag
}
// LocalPwd returns the local ICE password.
func (p *Participant) LocalPwd() string {
return p.localPwd
}
// GatherCandidates triggers ICE gathering and waits for completion.
// Returns the gathered ICE candidates.
func (p *Participant) GatherCandidates() ([]ice.Candidate, error) {
var (
candidates []ice.Candidate
mu sync.Mutex
done = make(chan struct{})
)
if err := p.iceAgent.OnCandidate(func(c ice.Candidate) {
if c == nil {
// nil candidate signals gathering complete.
close(done)
return
}
mu.Lock()
candidates = append(candidates, c)
mu.Unlock()
}); err != nil {
return nil, fmt.Errorf("set OnCandidate: %w", err)
}
if err := p.iceAgent.GatherCandidates(); err != nil {
return nil, fmt.Errorf("gather candidates: %w", err)
}
<-done
mu.Lock()
defer mu.Unlock()
p.log.Infof("Gathered %d ICE candidates", len(candidates))
return candidates, nil
}
// Connect establishes the full transport stack: ICE → DTLS → SRTP + SCTP.
// The SFU is DTLS client (active). tgcalls GroupNetworkManager hardcodes SSL_SERVER.
//
// iceControlling selects the ICE role:
// - true (Dial): SFU is controlling. Required for tgcalls GroupNetworkManager which
// hardcodes ICEROLE_CONTROLLED (non-standard).
// - false (Accept): SFU is controlled (standard for ICE-lite). Required for PeerConnection
// clients that follow RFC 8445 (full agent = controlling when remote is ice-lite).
func (p *Participant) Connect(ctx context.Context, remoteUfrag, remotePwd string, iceControlling bool) error {
// 1. ICE connection.
var iceConn *ice.Conn
var err error
if iceControlling {
iceConn, err = p.iceAgent.Dial(ctx, remoteUfrag, remotePwd)
} else {
iceConn, err = p.iceAgent.Accept(ctx, remoteUfrag, remotePwd)
}
if err != nil {
return fmt.Errorf("ICE dial: %w", err)
}
p.iceConn = iceConn
p.log.Infof("ICE connected")
// 2. Demux: split DTLS and SRTP traffic.
p.demux = NewPacketDemux(iceConn, fmt.Sprintf("p%d", p.ID))
// 3. DTLS: client-side handshake over the DTLS endpoint.
// tgcalls GroupNetworkManager hardcodes SetDtlsRole(SSL_SERVER), so the SFU must be the DTLS client.
dtlsEndpoint := p.demux.DTLSEndpoint()
remoteAddr := dtlsEndpoint.RemoteAddr()
packetConn := WrapAsPacketConn(dtlsEndpoint)
dtlsConn, err := dtls.Client(packetConn, remoteAddr, &dtls.Config{
Certificates: []tls.Certificate{p.tlsCert},
// Offer GCM profiles matching tgcalls GroupNetworkManager::getDefaulCryptoOptions()
// which enables enable_gcm_crypto_suites=true and disables AES-128-CM-SHA1-80.
SRTPProtectionProfiles: []dtls.SRTPProtectionProfile{
dtls.SRTP_AEAD_AES_256_GCM,
dtls.SRTP_AEAD_AES_128_GCM,
},
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
InsecureSkipVerify: true, // tgcalls verifies fingerprint out-of-band; we skip TLS chain verification
LoggerFactory: p.loggerFactory,
})
if err != nil {
p.demux.Close()
return fmt.Errorf("DTLS create: %w", err)
}
p.dtlsConn = dtlsConn
// dtls.Client() is lazy; explicitly run the handshake before accessing ConnectionState.
if err := dtlsConn.HandshakeContext(ctx); err != nil {
p.demux.Close()
return fmt.Errorf("DTLS handshake: %w", err)
}
p.log.Infof("DTLS connected")
// 4. Extract SRTP keying material from DTLS.
state, ok := dtlsConn.ConnectionState()
if !ok {
return fmt.Errorf("DTLS connection state not available")
}
// Map the negotiated DTLS-SRTP protection profile to a pion/srtp ProtectionProfile.
negotiatedProfile, profileOk := dtlsConn.SelectedSRTPProtectionProfile()
if !profileOk {
p.demux.Close()
return fmt.Errorf("no SRTP protection profile negotiated")
}
var srtpProfile srtp.ProtectionProfile
switch negotiatedProfile {
case dtls.SRTP_AEAD_AES_256_GCM:
srtpProfile = srtp.ProtectionProfileAeadAes256Gcm
case dtls.SRTP_AEAD_AES_128_GCM:
srtpProfile = srtp.ProtectionProfileAeadAes128Gcm
case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
srtpProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
case dtls.SRTP_AES128_CM_HMAC_SHA1_32:
srtpProfile = srtp.ProtectionProfileAes128CmHmacSha1_32
default:
p.demux.Close()
return fmt.Errorf("unsupported SRTP protection profile: 0x%04x", negotiatedProfile)
}
p.log.Infof("Negotiated SRTP profile: 0x%04x", negotiatedProfile)
srtpConfig := &srtp.Config{
Profile: srtpProfile,
}
// SFU is DTLS client → isClient=true
if err := srtpConfig.ExtractSessionKeysFromDTLS(&state, true); err != nil {
return fmt.Errorf("extract SRTP keys: %w", err)
}
// Save keys and profile for creating SRTCP contexts.
p.srtpProfile = srtpProfile
p.srtpKeys = srtpConfig.Keys
// 5. SRTP session over the SRTP endpoint (RTP only — RTCP is handled separately).
srtpEndpoint := p.demux.SRTPEndpoint()
srtpSession, err := srtp.NewSessionSRTP(srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("create SRTP session: %w", err)
}
p.srtpSession = srtpSession
srtpWriter, err := srtpSession.OpenWriteStream()
if err != nil {
return fmt.Errorf("open SRTP write stream: %w", err)
}
p.srtpWriter = srtpWriter
p.log.Infof("SRTP session established")
// 5b. Create separate SRTCP contexts for manual RTCP handling.
// Remote context: decrypt SRTCP received from this participant (their local = our remote).
p.srtcpRemoteCtx, err = srtp.CreateContext(
p.srtpKeys.RemoteMasterKey, p.srtpKeys.RemoteMasterSalt, p.srtpProfile,
)
if err != nil {
return fmt.Errorf("create SRTCP remote context: %w", err)
}
// Local context: encrypt SRTCP we send to this participant (our local keys).
p.srtcpLocalCtx, err = srtp.CreateContext(
p.srtpKeys.LocalMasterKey, p.srtpKeys.LocalMasterSalt, p.srtpProfile,
)
if err != nil {
return fmt.Errorf("create SRTCP local context: %w", err)
}
p.log.Infof("SRTCP contexts created")
// 5c. Start RTCP read loop.
go p.readRTCPLoop()
// 6. SCTP association over DTLS.
sctpAssoc, err := sctp.Server(sctp.Config{
NetConn: dtlsConn,
LoggerFactory: p.loggerFactory,
})
if err != nil {
return fmt.Errorf("create SCTP association: %w", err)
}
p.sctpAssoc = sctpAssoc
p.log.Infof("SCTP association established")
// 7. Start goroutine to accept data channels.
go p.acceptDataChannel()
return nil
}
// acceptDataChannel waits for the client to open a data channel and reads Colibri messages.
func (p *Participant) acceptDataChannel() {
dc, err := datachannel.Accept(p.sctpAssoc, &datachannel.Config{
LoggerFactory: p.loggerFactory,
})
if err != nil {
select {
case <-p.closed:
return // Expected during shutdown.
default:
p.log.Warnf("Accept data channel: %v", err)
return
}
}
p.dataChannel = dc
p.log.Infof("Data channel accepted")
buf := make([]byte, 4096)
for {
n, isString, err := dc.ReadDataChannel(buf)
if err != nil {
select {
case <-p.closed:
return
default:
p.log.Debugf("Data channel read error: %v", err)
return
}
}
if isString {
msg := string(buf[:n])
p.log.Debugf("Colibri message: %s", msg)
if p.onColibriMessage != nil {
p.onColibriMessage(p.ID, msg)
}
} else {
p.log.Debugf("Data channel binary message (%d bytes)", n)
}
}
}
// SetColibriCallback sets the callback for incoming Colibri data channel messages.
func (p *Participant) SetColibriCallback(cb func(participantID int, msg string)) {
p.onColibriMessage = cb
}
// SetRTCPFeedbackCallback sets the callback for PLI/FIR RTCP feedback from this participant.
func (p *Participant) SetRTCPFeedbackCallback(cb func(participantID int, mediaSSRC uint32, isFIR bool)) {
p.onRTCPFeedback = cb
}
// readRTCPLoop reads encrypted SRTCP packets from the demux RTCP channel,
// decrypts them, parses for PLI/FIR, and invokes the feedback callback.
func (p *Participant) readRTCPLoop() {
rtcpCh := p.demux.RTCPChannel()
decryptBuf := make([]byte, 8192)
pktCount := 0
for {
select {
case <-p.closed:
return
case encrypted, ok := <-rtcpCh:
if !ok {
return
}
// Decrypt SRTCP.
decrypted, err := p.srtcpRemoteCtx.DecryptRTCP(decryptBuf[:0], encrypted, nil)
if err != nil {
pktCount++
if pktCount <= 5 {
p.log.Debugf("SRTCP decrypt error: %v", err)
}
continue
}
// Parse RTCP compound packet.
packets, err := rtcp.Unmarshal(decrypted)
if err != nil {
p.log.Debugf("RTCP unmarshal error: %v", err)
continue
}
for _, pkt := range packets {
switch fb := pkt.(type) {
case *rtcp.PictureLossIndication:
p.log.Infof("Received PLI from participant %d for MediaSSRC=%d", p.ID, fb.MediaSSRC)
if p.onRTCPFeedback != nil {
p.onRTCPFeedback(p.ID, fb.MediaSSRC, false)
}
case *rtcp.FullIntraRequest:
for _, entry := range fb.FIR {
p.log.Infof("Received FIR from participant %d for SSRC=%d", p.ID, entry.SSRC)
if p.onRTCPFeedback != nil {
p.onRTCPFeedback(p.ID, entry.SSRC, true)
}
}
case *rtcp.ReceiverEstimatedMaximumBitrate:
bps := float64(fb.Bitrate)
p.bwEstimator.OnREMB(bps)
p.log.Debugf("REMB from participant %d: %.0f bps (smoothed=%.0f, effective=%.0f)",
p.ID, bps, p.bwEstimator.SmoothedBps(), p.bwEstimator.EffectiveBps())
}
}
}
}
}
// WriteRTCP sends a plaintext RTCP packet to this participant, encrypting it with
// the local SRTCP context and writing directly to the ICE connection.
func (p *Participant) WriteRTCP(data []byte) error {
if p.srtcpLocalCtx == nil || p.iceConn == nil {
return fmt.Errorf("SRTCP context or ICE conn not established")
}
if p.egressSim.IsPassthrough() {
return p.writeRTCPDirect(data)
}
p.egressSim.Send(data, func(delayed []byte) {
p.writeRTCPDirect(delayed)
})
return nil
}
func (p *Participant) writeRTCPDirect(data []byte) error {
p.srtcpMu.Lock()
encrypted, err := p.srtcpLocalCtx.EncryptRTCP(nil, data, nil)
p.srtcpMu.Unlock()
if err != nil {
return fmt.Errorf("encrypt SRTCP: %w", err)
}
_, err = p.iceConn.Write(encrypted)
return err
}
// SetRequestedLayer sets the video layer this receiver wants from a given sender.
func (p *Participant) SetRequestedLayer(senderID int, layer int) {
p.videoLayerMu.Lock()
p.requestedLayers[senderID] = layer
p.videoLayerMu.Unlock()
}
// GetRequestedLayer returns the video layer this receiver wants from a given sender.
// Returns -1 if no layer is requested (meaning: don't forward video from this sender).
func (p *Participant) GetRequestedLayer(senderID int) int {
p.videoLayerMu.RLock()
defer p.videoLayerMu.RUnlock()
if layer, ok := p.requestedLayers[senderID]; ok {
return layer
}
return -1
}
// SetSelectedLayer sets the video layer the SFU actually forwards from a given sender to this receiver.
func (p *Participant) SetSelectedLayer(senderID int, layer int) {
p.selectedLayerMu.Lock()
p.selectedLayers[senderID] = layer
p.selectedLayerMu.Unlock()
}
// GetSelectedLayer returns the video layer the SFU forwards from a given sender to this receiver.
// Returns -1 if no layer is selected (don't forward).
func (p *Participant) GetSelectedLayer(senderID int) int {
p.selectedLayerMu.RLock()
defer p.selectedLayerMu.RUnlock()
if layer, ok := p.selectedLayers[senderID]; ok {
return layer
}
return -1
}
// SendText sends a UTF-8 string message over the data channel.
// Returns an error if the data channel is not yet established.
func (p *Participant) SendText(msg string) error {
dc := p.dataChannel
if dc == nil {
return fmt.Errorf("data channel not established")
}
_, err := dc.WriteDataChannel([]byte(msg), true)
return err
}
// WriteRTP sends an encrypted RTP packet to this participant via the SRTP write stream.
func (p *Participant) WriteRTP(pkt []byte) (int, error) {
if p.srtpWriter == nil {
return 0, fmt.Errorf("SRTP session not established")
}
if p.egressSim.IsPassthrough() {
return p.srtpWriter.Write(pkt)
}
var n int
var writeErr error
p.egressSim.Send(pkt, func(delayed []byte) {
n, writeErr = p.srtpWriter.Write(delayed)
})
return n, writeErr
}
// AcceptStream blocks until a new SRTP read stream appears (new SSRC from client).
// Returns the read stream and its SSRC.
func (p *Participant) AcceptStream() (*srtp.ReadStreamSRTP, uint32, error) {
if p.srtpSession == nil {
return nil, 0, fmt.Errorf("SRTP session not established")
}
return p.srtpSession.AcceptStream()
}
// Close tears down all transport layers in order.
func (p *Participant) Close() error {
var firstErr error
p.once.Do(func() {
close(p.closed)
if p.ingressSim != nil {
p.ingressSim.Close()
}
if p.egressSim != nil {
p.egressSim.Close()
}
if p.dataChannel != nil {
if err := p.dataChannel.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.sctpAssoc != nil {
if err := p.sctpAssoc.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.srtpSession != nil {
if err := p.srtpSession.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.dtlsConn != nil {
if err := p.dtlsConn.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.demux != nil {
if err := p.demux.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.iceConn != nil {
if err := p.iceConn.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if p.iceAgent != nil {
if err := p.iceAgent.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
p.log.Infof("Participant %d closed", p.ID)
})
return firstErr
}
// formatFingerprint converts a hash byte slice to colon-separated uppercase hex.
func formatFingerprint(hash []byte) string {
result := make([]byte, 0, len(hash)*3-1)
for i, b := range hash {
if i > 0 {
result = append(result, ':')
}
result = append(result, fmt.Sprintf("%02X", b)...)
}
return string(result)
}
+1240
View File
File diff suppressed because it is too large Load Diff
+262
View File
@@ -0,0 +1,262 @@
package main
import (
"encoding/binary"
"sync"
"time"
"github.com/pion/rtcp"
)
// --- RTP Header Extension Parsing ---
// parseTWCCSeq extracts the transport-wide sequence number from an RTP packet.
// extID is the header extension ID to look for (typically 3).
// Returns the sequence number and true if found, or 0 and false.
func parseTWCCSeq(pkt []byte, extID int) (uint16, bool) {
if len(pkt) < 12 {
return 0, false
}
// Check extension bit (X) in RTP header byte 0.
if pkt[0]&0x10 == 0 {
return 0, false
}
// Skip fixed header (12 bytes) + CSRC list.
cc := int(pkt[0] & 0x0F)
offset := 12 + cc*4
if offset+4 > len(pkt) {
return 0, false
}
// Check for one-byte header extension (0xBEDE magic).
if pkt[offset] != 0xBE || pkt[offset+1] != 0xDE {
return 0, false
}
// Extension length in 32-bit words.
extLen := int(binary.BigEndian.Uint16(pkt[offset+2:])) * 4
offset += 4
extEnd := offset + extLen
if extEnd > len(pkt) {
return 0, false
}
// Scan extension elements: [id:4][len:4][data...].
for offset < extEnd {
b := pkt[offset]
if b == 0 {
// Padding byte.
offset++
continue
}
id := int(b >> 4)
dataLen := int(b&0x0F) + 1 // len field is 0-based
offset++
if id == extID && dataLen >= 2 && offset+2 <= extEnd {
seq := binary.BigEndian.Uint16(pkt[offset:])
return seq, true
}
offset += dataLen
}
return 0, false
}
// --- Transport-CC Feedback Generator ---
type twccArrival struct {
seq uint16
arrivalUs int64 // microseconds since generator creation
}
// TransportCCGenerator generates RTCP transport-cc feedback for a sender.
// It tracks packet arrivals and emits feedback every 100ms.
type TransportCCGenerator struct {
mu sync.Mutex
arrivals []twccArrival
startTime time.Time
fbCount uint8 // feedback packet counter
// Callback to send the feedback RTCP packet.
sendFeedback func(data []byte)
stopCh chan struct{}
done chan struct{}
}
// NewTransportCCGenerator creates and starts a generator.
// sendFeedback is called with marshalled+encrypted RTCP data to send to the sender.
func NewTransportCCGenerator(sendFeedback func(data []byte)) *TransportCCGenerator {
g := &TransportCCGenerator{
startTime: time.Now(),
sendFeedback: sendFeedback,
stopCh: make(chan struct{}),
done: make(chan struct{}),
}
go g.run()
return g
}
// RecordArrival records a packet arrival. Thread-safe.
func (g *TransportCCGenerator) RecordArrival(twccSeq uint16) {
g.mu.Lock()
defer g.mu.Unlock()
arrivalUs := time.Since(g.startTime).Microseconds()
g.arrivals = append(g.arrivals, twccArrival{seq: twccSeq, arrivalUs: arrivalUs})
}
// Stop terminates the generator.
func (g *TransportCCGenerator) Stop() {
close(g.stopCh)
<-g.done
}
func (g *TransportCCGenerator) run() {
defer close(g.done)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-g.stopCh:
return
case <-ticker.C:
g.emitFeedback()
}
}
}
func (g *TransportCCGenerator) emitFeedback() {
g.mu.Lock()
if len(g.arrivals) == 0 {
g.mu.Unlock()
return
}
// Take all arrivals.
arrivals := g.arrivals
g.arrivals = nil
g.fbCount++
fbCount := g.fbCount
g.mu.Unlock()
// Sort by sequence number (should already be mostly sorted).
for i := 1; i < len(arrivals); i++ {
for j := i; j > 0 && seqBefore(arrivals[j].seq, arrivals[j-1].seq); j-- {
arrivals[j], arrivals[j-1] = arrivals[j-1], arrivals[j]
}
}
baseSeq := arrivals[0].seq
// Number of sequence numbers covered (including gaps).
lastSeq := arrivals[len(arrivals)-1].seq
packetCount := seqDiff(baseSeq, lastSeq) + 1
// Reference time: arrival of first packet in 64ms units.
refTimeUs := arrivals[0].arrivalUs
refTime := uint32(refTimeUs / 64000) // 64ms units, 24-bit in spec but stored as uint32
// Build received set for gap detection.
receivedAt := make(map[uint16]int64, len(arrivals))
for _, a := range arrivals {
receivedAt[a.seq] = a.arrivalUs
}
// Build packet chunks and recv deltas.
var chunks []rtcp.PacketStatusChunk
var deltas []*rtcp.RecvDelta
// Process in runs of up to 7 (status vector chunk capacity for 2-bit symbols).
prevArrivalUs := refTimeUs
var statusList []uint16
seq := baseSeq
for i := 0; i < int(packetCount); i++ {
arrUs, received := receivedAt[seq]
if received {
deltaUs := arrUs - prevArrivalUs
if deltaUs >= 0 && deltaUs <= 63750 { // fits in small delta (0-255 * 250us)
statusList = append(statusList, rtcp.TypeTCCPacketReceivedSmallDelta)
deltas = append(deltas, &rtcp.RecvDelta{
Type: rtcp.TypeTCCPacketReceivedSmallDelta,
Delta: deltaUs,
})
} else {
statusList = append(statusList, rtcp.TypeTCCPacketReceivedLargeDelta)
deltas = append(deltas, &rtcp.RecvDelta{
Type: rtcp.TypeTCCPacketReceivedLargeDelta,
Delta: deltaUs,
})
}
prevArrivalUs = arrUs
} else {
statusList = append(statusList, rtcp.TypeTCCPacketNotReceived)
}
seq++
}
// Encode status list as status vector chunks (7 symbols per chunk with 2-bit symbols).
for i := 0; i < len(statusList); i += 7 {
end := i + 7
if end > len(statusList) {
end = len(statusList)
}
chunk := statusList[i:end]
// Check if all same status (use run-length).
allSame := true
for _, s := range chunk {
if s != chunk[0] {
allSame = false
break
}
}
if allSame && len(chunk) >= 2 {
chunks = append(chunks, &rtcp.RunLengthChunk{
Type: rtcp.TypeTCCRunLengthChunk,
PacketStatusSymbol: chunk[0],
RunLength: uint16(len(chunk)),
})
} else {
// Status vector with 2-bit symbols.
symbolList := make([]uint16, len(chunk))
copy(symbolList, chunk)
chunks = append(chunks, &rtcp.StatusVectorChunk{
Type: rtcp.TypeTCCStatusVectorChunk,
SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit,
SymbolList: symbolList,
})
}
}
fb := &rtcp.TransportLayerCC{
SenderSSRC: 1,
MediaSSRC: 0,
BaseSequenceNumber: baseSeq,
PacketStatusCount: packetCount,
ReferenceTime: refTime,
FbPktCount: fbCount,
PacketChunks: chunks,
RecvDeltas: deltas,
}
data, err := rtcp.Marshal([]rtcp.Packet{fb})
if err != nil {
return
}
g.sendFeedback(data)
}
// seqBefore returns true if a comes before b in the uint16 sequence space.
func seqBefore(a, b uint16) bool {
return int16(a-b) < 0
}
// seqDiff returns the forward distance from a to b in uint16 sequence space.
func seqDiff(a, b uint16) uint16 {
return b - a
}