diff --git a/rtsp/RTSPServer.go b/rtsp/RTSPServer.go index 4b9cd39..41f7909 100644 --- a/rtsp/RTSPServer.go +++ b/rtsp/RTSPServer.go @@ -1,10 +1,11 @@ package rtsp import ( + "bytes" "context" "log" - "sync" "strconv" + "sync" "github.com/bluenviron/gortsplib/v5" "github.com/bluenviron/gortsplib/v5/pkg/base" @@ -44,15 +45,16 @@ type Server struct { wg sync.WaitGroup } -func CreateServer(ctx context.Context, host string, port int, relay *libipcamera.RTPRelay) *Server { +func CreateServer(ctx context.Context, ip string, port int, relay *libipcamera.RTPRelay) *Server { cctx, cancel := context.WithCancel(ctx) h := &Handler{} srv := &Server{cancel: cancel} + // create RTSP server h.server = &gortsplib.Server{ Handler: h, - RTSPAddress: host + ":" + strconv.Itoa(port), + RTSPAddress: ip + ":" + strconv.Itoa(port), UDPRTPAddress: ":8000", UDPRTCPAddress: ":8001", } @@ -61,6 +63,7 @@ func CreateServer(ctx context.Context, host string, port int, relay *libipcamera panic(err) } + // create SDP + stream h.mu.Lock() desc := &description.Session{ Medias: []*description.Media{{ @@ -79,41 +82,137 @@ func CreateServer(ctx context.Context, host string, port int, relay *libipcamera srv.stream = h.stream h.mu.Unlock() - // Pump frames -> RTP - srv.wg.Add(1) + // start streaming goroutine + srv.startPump(cctx, relay, h) + + log.Printf("RTSP server ready: rtsp://%s:%d/", ip, port) + return srv +} + +// ---- STREAMING / H264 → RTP ---- + +func (s *Server) startPump(cctx context.Context, relay *libipcamera.RTPRelay, h *Handler) { + s.wg.Add(1) go func() { - defer srv.wg.Done() + defer s.wg.Done() + var seq uint16 var ts uint32 + + var sps, pps []byte + for { select { case <-cctx.Done(): return + case frame, ok := <-relay.Frames: if !ok { return } - pkt := &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - Marker: true, - PayloadType: 96, - SequenceNumber: seq, - Timestamp: ts, - }, - Payload: frame.Data, + + nalus := splitNALUnits(frame.Data) + for _, nalu := range nalus { + if len(nalu) < 1 { + continue + } + nalType := nalu[0] & 0x1F + + switch nalType { + case 7: // SPS + sps = append([]byte{}, nalu...) + continue + case 8: // PPS + pps = append([]byte{}, nalu...) + continue + } + + if nalType == 5 && sps != nil && pps != nil { + sendNALUnit(h, s.stream, sps, &seq, &ts) + sendNALUnit(h, s.stream, pps, &seq, &ts) + } + sendNALUnit(h, s.stream, nalu, &seq, &ts) } - seq++ - ts += 3600 - h.stream.WritePacketRTP(h.stream.Desc.Medias[0], pkt) } } }() - - log.Printf("RTSP server ready at rtsp://%s:%d/", host, port) - return srv } +func splitNALUnits(frame []byte) [][]byte { + var nalus [][]byte + for { + idx := bytes.Index(frame, []byte{0, 0, 0, 1}) + if idx == -1 { + if len(frame) > 0 { + nalus = append(nalus, frame) + } + return nalus + } + if idx != 0 { + nalus = append(nalus, frame[:idx]) + } + frame = frame[idx+4:] + } +} + +func sendNALUnit(h *Handler, stream *gortsplib.ServerStream, nalu []byte, seq *uint16, ts *uint32) { + const maxPayload = 1200 + + if len(nalu) <= maxPayload { + writeRTP(h, stream, nalu, true, seq, ts) + return + } + + // FU-A fragmentation + first := true + naluHeader := nalu[0] + payload := nalu[1:] + + for len(payload) > 0 { + size := maxPayload + if len(payload) < size { + size = len(payload) + } + + fuHeader := byte(28) | (naluHeader & 0x60) + startFlag := byte(0x80) + endFlag := byte(0x40) + + b := []byte{fuHeader, 0} + b[1] = naluHeader & 0x1F + + if first { + b[1] |= startFlag + first = false + } else if len(payload) <= size { + b[1] |= endFlag + } + + packet := append(b, payload[:size]...) + writeRTP(h, stream, packet, len(payload) <= size, seq, ts) + payload = payload[size:] + } +} + +func writeRTP(h *Handler, stream *gortsplib.ServerStream, payload []byte, marker bool, seq *uint16, ts *uint32) { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: marker, + PayloadType: 96, + SequenceNumber: *seq, + Timestamp: *ts, + }, + Payload: payload, + } + + stream.WritePacketRTP(stream.Desc.Medias[0], pkt) + *seq++ + *ts += 3600 +} + +// ---- SHUTDOWN ---- + func (s *Server) Stop() { s.cancel() s.server.Close()