import Foundation
import UIKit
import Display
import Vision
import CoreImage
import CoreImage.CIFilterBuiltins
import VideoToolbox
import SwiftSignalKit
import TelegramCore
import AccountContext
import FileMediaResourceStatus
import ZipArchive
import ImageTransparency

/// Queue on which the cutout signals and the Vision requests in this file are run.
private let queue = Queue()

/// Availability of the background-separation ("cutout") feature.
public enum CutoutAvailability {
    /// Cutout can be performed right away.
    case available
    /// The Core ML model is still being fetched; the payload is the fetch progress.
    case progress(Float)
    /// Cutout cannot be performed.
    case unavailable
}

/// `true` when the Core ML code path must be used even if the iOS 17 Vision path is available.
/// This is the case when building for the simulator.
private var forceCoreMLVariant: Bool {
#if targetEnvironment(simulator)
    return true
#else
    return false
#endif
}

/// Location of the compiled `u2netp` model inside the temporary directory.
private func modelPath() -> String {
    return NSTemporaryDirectory() + "u2netp.mlmodelc"
}

/// Reports whether image cutout is available, fetching the Core ML model when needed.
///
/// - On iOS 17+ (unless `forceCoreMLVariant` is set) the result is `.available` immediately.
/// - On iOS 14+ the result is `.available` when the compiled model already exists at `modelPath()`.
///   Otherwise the peer named `stickersbackgroundseparation` is resolved, the file attached to the
///   last loaded history entry is fetched and unzipped into the temporary directory, and
///   `.progress` values are emitted while the fetch is running.
/// - On older systems the result is `.unavailable`.
///
/// - Parameter context: Account context used to resolve the peer and fetch the model archive.
/// - Returns: A signal of availability updates.
public func cutoutAvailability(context: AccountContext) -> Signal<CutoutAvailability, NoError> {
    if #available(iOS 17.0, *), !forceCoreMLVariant {
        return .single(.available)
    } else if #available(iOS 14.0, *) {
        let compiledModelPath = modelPath()
        
        if FileManager.default.fileExists(atPath: compiledModelPath) {
            return .single(.available)
        }
        
        return context.engine.peers.resolvePeerByName(name: "stickersbackgroundseparation", referrer: nil)
        |> mapToSignal { result -> Signal<CutoutAvailability, NoError> in
            guard case let .result(maybePeer) = result else {
                return .complete()
            }
            guard let peer = maybePeer else {
                return .single(.unavailable)
            }
            
            return context.account.viewTracker.aroundMessageHistoryViewForLocation(.peer(peerId: peer.id, threadId: nil), index: .lowerBound, anchorIndex: .lowerBound, count: 5, fixedCombinedReadStates: nil)
            |> mapToSignal { view -> Signal<(TelegramMediaFile, EngineMessage)?, NoError> in
                // Wait for a loaded view; only then decide whether a model file is present.
                if !view.0.isLoading {
                    if let message = view.0.entries.last?.message, let file = message.media.first(where: { $0 is TelegramMediaFile }) as? TelegramMediaFile {
                        return .single((file, EngineMessage(message)))
                    } else {
                        return .single(nil)
                    }
                } else {
                    return .complete()
                }
            }
            |> take(1)
            |> mapToSignal { maybeFileAndMessage -> Signal<CutoutAvailability, NoError> in
                if let (file, message) = maybeFileAndMessage {
                    let fetchedData = context.engine.resources.fetch(reference: FileMediaReference.message(message: MessageReference(message._asMessage()), media: file).resourceReference(file.resource), userLocation: .other, userContentType: .file)
                    
                    enum FetchStatus {
                        case completed(String)
                        case progress(Float)
                        case failed
                    }
                    
                    let fetchStatus = Signal<FetchStatus, NoError> { subscriber in
                        let fetchedDisposable = fetchedData.start()
                        let resourceDataDisposable = context.engine.resources.data(resource: EngineMediaResource(file.resource)).start(next: { next in
                            if next.isComplete {
                                // Only report success when the archive was actually extracted;
                                // otherwise the model would be advertised as available while missing.
                                if SSZipArchive.unzipFile(atPath: next.path, toDestination: NSTemporaryDirectory()) {
                                    subscriber.putNext(.completed(compiledModelPath))
                                } else {
                                    subscriber.putNext(.failed)
                                }
                                subscriber.putCompletion()
                            }
                        }, error: subscriber.putError, completed: subscriber.putCompletion)
                        
                        let progressDisposable = messageFileMediaResourceStatus(context: context, file: file, message: message, isRecentActions: false).start(next: { status in
                            switch status.fetchStatus {
                            case let .Remote(progress), let .Fetching(_, progress), let .Paused(progress):
                                subscriber.putNext(.progress(progress))
                            default:
                                break
                            }
                        })
                        
                        return ActionDisposable {
                            fetchedDisposable.dispose()
                            resourceDataDisposable.dispose()
                            progressDisposable.dispose()
                        }
                    }
                    
                    return fetchStatus
                    |> mapToSignal { status -> Signal<CutoutAvailability, NoError> in
                        switch status {
                        case .completed:
                            return .single(.available)
                        case let .progress(progress):
                            return .single(.progress(progress))
                        case .failed:
                            return .single(.unavailable)
                        }
                    }
                } else {
                    return .single(.unavailable)
                }
            }
        }
    } else {
        return .single(.unavailable)
    }
}

/// Cuts the foreground subject(s) out of `image`, or merely checks that a subject exists.
///
/// - On iOS 17+ (unless `forceCoreMLVariant` is set) a Vision foreground-instance mask request is
///   used. With `onlyCheck` an empty `UIImage()` is emitted as soon as an observation is found;
///   otherwise the image blended over a clear background through the mask of all instances is emitted.
/// - On iOS 14+ only the `onlyCheck` mode is supported: the `U2netp` model is run on a 320x320
///   version of the image and an empty `UIImage()` is emitted when `imageHasSubject` accepts the
///   predicted mask.
/// - In every other case, and on any failure, `nil` is emitted.
///
/// - Parameters:
///   - image: Source image; must be backed by a `CGImage`.
///   - context: Not used by this function.
///   - onlyCheck: When `true`, only the presence of a subject is reported.
/// - Returns: A signal emitting a single optional image and then completing.
public func cutoutStickerImage(from image: UIImage, context: AccountContext? = nil, onlyCheck: Bool = false) -> Signal<UIImage?, NoError> {
    guard let cgImage = image.cgImage else {
        return .single(nil)
    }
    if #available(iOS 17.0, *), !forceCoreMLVariant {
        return Signal { subscriber in
            let ciContext = CIContext(options: nil)
            let inputImage = CIImage(cgImage: cgImage)
            let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
            let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
                guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                    return
                }
                if onlyCheck {
                    subscriber.putNext(UIImage())
                    subscriber.putCompletion()
                } else {
                    // A nil point selects every detected instance.
                    let targetInstances = instances(atPoint: nil, inObservation: result)
                    if let mask = try? result.generateScaledMaskForImage(forInstances: targetInstances, from: handler) {
                        let filter = CIFilter.blendWithMask()
                        filter.inputImage = inputImage
                        filter.backgroundImage = CIImage(color: .clear)
                        filter.maskImage = CIImage(cvPixelBuffer: mask)
                        if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: inputImage.extent) {
                            let image = UIImage(cgImage: cgImage)
                            subscriber.putNext(image)
                            subscriber.putCompletion()
                            return
                        }
                    }
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                }
            }
            try? handler.perform([request])
            return ActionDisposable {
                request.cancel()
            }
        }
        |> runOn(queue)
    } else if #available(iOS 14.0, *), onlyCheck {
        return Signal { subscriber in
            U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
                switch result {
                case let .success(model):
                    let modelImageSize = CGSize(width: 320, height: 320)
                    if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize),
                       let pixelBuffer = buffer(from: squareImage),
                       let result = try? model.prediction(in_0: pixelBuffer),
                       let resultImage = UIImage(pixelBuffer: result.out_p1),
                       imageHasSubject(resultImage) {
                        subscriber.putNext(UIImage())
                    } else {
                        subscriber.putNext(nil)
                    }
                    subscriber.putCompletion()
                case .failure:
                    subscriber.putNext(nil)
                    subscriber.putCompletion()
                }
            })
            return EmptyDisposable
        }
        |> runOn(queue)
    } else {
        return .single(nil)
    }
}

/// Result of cutting a single instance out of an image.
public struct CutoutResult {
    /// Backing representation of a produced image.
    public enum Image {
        case image(UIImage, CIImage)
        case pixelBuffer(CVPixelBuffer)
    }
    
    /// Index of the instance this result was produced for.
    public let index: Int
    /// The source (or edited) image blended over a clear background through the instance mask.
    public let extractedImage: Image?
    /// White-on-black rendering of the unrefined instance mask.
    public let edgesMaskImage: Image?
    /// White-on-black rendering of the mask after `refineEdges`.
    public let maskImage: Image?
    /// Not populated by `cutoutImage`, which always passes `nil`.
    public let backgroundImage: Image?
}

/// Selects which instances `cutoutImage` should process.
public enum CutoutTarget {
    /// The instance under a normalized point, or every instance when the point is `nil`.
    case point(CGPoint?)
    /// The instance with the given index.
    case index(Int)
    /// Every instance, each processed separately.
    case all
}

/// Post-processes a mask by blurring it, boosting its contrast and sharpening its luminance.
///
/// - Parameter maskImage: Mask to refine.
/// - Returns: The refined mask, or `nil` when a filter is unavailable or produces no output.
func refineEdges(_ maskImage: CIImage) -> CIImage? {
    let clampedMask = maskImage.clampedToExtent()
    
    guard let blurFilter = CIFilter(name: "CIGaussianBlur"),
          let controlsFilter = CIFilter(name: "CIColorControls"),
          let sharpenFilter = CIFilter(name: "CISharpenLuminance") else {
        return nil
    }
    
    blurFilter.setValue(clampedMask, forKey: kCIInputImageKey)
    blurFilter.setValue(11.4, forKey: kCIInputRadiusKey)
    
    controlsFilter.setValue(blurFilter.outputImage, forKey: kCIInputImageKey)
    controlsFilter.setValue(6.61, forKey: kCIInputContrastKey)
    
    sharpenFilter.setValue(controlsFilter.outputImage, forKey: kCIInputImageKey)
    sharpenFilter.setValue(250.0, forKey: kCIInputSharpnessKey)
    
    // NOTE(review): the crop rectangle is the extent of the *clamped* image, which is presumably
    // infinite, so this crop likely has no effect and callers crop again themselves.
    // Confirm whether the extent of the original mask was intended here.
    return sharpenFilter.outputImage?.cropped(to: clampedMask.extent)
}

/// Produces cutout results (extracted image and masks) for the requested instances of `image`.
///
/// On iOS 17+ (unless `forceCoreMLVariant` is set) Vision's foreground-instance mask request is
/// performed on `queue` and each target instance is processed separately. Otherwise the `U2netp`
/// model is run on a 320x320 version of the image and a single result with index 0 is produced.
///
/// - Parameters:
///   - image: Source image; must be backed by a `CGImage`.
///   - editedImage: Optional edited version of the image used as the source of the extracted image.
///   - crop: Offset, rotation and scale used to map the mask onto `editedImage`.
///   - target: Which instances to process (only consulted on the Vision path).
///   - includeExtracted: When `false`, `extractedImage` is left `nil` in the results.
///   - completion: Called with the collected results; called with an empty array when the system
///     is older than iOS 14, the image has no `CGImage`, or no observation was produced.
public func cutoutImage(
    from image: UIImage,
    editedImage: UIImage? = nil,
    crop: (offset: CGPoint, rotation: CGFloat, scale: CGFloat)?,
    target: CutoutTarget,
    includeExtracted: Bool = true,
    completion: @escaping ([CutoutResult]) -> Void
) {
    guard #available(iOS 14.0, *), let cgImage = image.cgImage else {
        completion([])
        return
    }
    
    let ciContext = CIContext(options: nil)
    let inputImage = CIImage(cgImage: cgImage)
    
    var results: [CutoutResult] = []
    
    // Builds the extracted image and both mask renderings for one instance and appends the result.
    func process(instance: Int, mask originalMaskImage: CIImage) {
        let extractedImage: CutoutResult.Image?
        if includeExtracted {
            let filter = CIFilter.blendWithMask()
            filter.backgroundImage = CIImage(color: .clear)
            
            let dimensions: CGSize
            var maskImage = originalMaskImage
            if let editedImage = editedImage?.cgImage.flatMap({ CIImage(cgImage: $0) }) {
                filter.inputImage = editedImage
                dimensions = editedImage.extent.size
                
                if let (cropOffset, cropRotation, cropScale) = crop {
                    // Both branches divide the edited image's width by the shorter mask side.
                    let initialScale: CGFloat
                    if maskImage.extent.height > maskImage.extent.width {
                        initialScale = dimensions.width / maskImage.extent.width
                    } else {
                        initialScale = dimensions.width / maskImage.extent.height
                    }
                    
                    // Center the mask on the origin so rotation and scale are applied around its center.
                    maskImage = maskImage.transformed(by: CGAffineTransform(translationX: -maskImage.extent.width / 2.0, y: -maskImage.extent.height / 2.0))
                    
                    var transform = CGAffineTransform.identity
                    transform = transform.translatedBy(x: dimensions.width / 2.0 + cropOffset.x, y: dimensions.height / 2.0 + cropOffset.y * -1.0)
                    transform = transform.rotated(by: -cropRotation)
                    transform = transform.scaledBy(x: cropScale * initialScale, y: cropScale * initialScale)
                    maskImage = maskImage.transformed(by: transform)
                }
            } else {
                filter.inputImage = inputImage
                dimensions = inputImage.extent.size
            }
            filter.maskImage = maskImage
            
            if let output = filter.outputImage, let cgImage = ciContext.createCGImage(output, from: CGRect(origin: .zero, size: dimensions)) {
                extractedImage = .image(UIImage(cgImage: cgImage), output)
            } else {
                extractedImage = nil
            }
        } else {
            extractedImage = nil
        }
        
        let whiteImage = CIImage(color: .white)
        let blackImage = CIImage(color: .black)
        
        let maskFilter = CIFilter.blendWithMask()
        maskFilter.inputImage = whiteImage
        maskFilter.backgroundImage = blackImage
        maskFilter.maskImage = originalMaskImage
        
        let refinedMaskFilter = CIFilter.blendWithMask()
        refinedMaskFilter.inputImage = whiteImage
        refinedMaskFilter.backgroundImage = blackImage
        refinedMaskFilter.maskImage = refineEdges(originalMaskImage)
        
        let edgesMaskImage: CutoutResult.Image?
        let maskImage: CutoutResult.Image?
        if let maskOutput = maskFilter.outputImage?.cropped(to: inputImage.extent),
           let maskCgImage = ciContext.createCGImage(maskOutput, from: inputImage.extent),
           let refinedMaskOutput = refinedMaskFilter.outputImage?.cropped(to: inputImage.extent),
           let refinedMaskCgImage = ciContext.createCGImage(refinedMaskOutput, from: inputImage.extent) {
            // `edgesMaskImage` carries the unrefined mask, `maskImage` the refined one.
            edgesMaskImage = .image(UIImage(cgImage: maskCgImage), maskOutput)
            maskImage = .image(UIImage(cgImage: refinedMaskCgImage), refinedMaskOutput)
        } else {
            edgesMaskImage = nil
            maskImage = nil
        }
        
        if extractedImage != nil || maskImage != nil {
            results.append(CutoutResult(index: instance, extractedImage: extractedImage, edgesMaskImage: edgesMaskImage, maskImage: maskImage, backgroundImage: nil))
        }
    }
    
    if #available(iOS 17.0, *), !forceCoreMLVariant {
        queue.async {
            let handler = VNImageRequestHandler(cgImage: cgImage, options: [:])
            let request = VNGenerateForegroundInstanceMaskRequest { [weak handler] request, error in
                guard let handler, let result = request.results?.first as? VNInstanceMaskObservation else {
                    completion([])
                    return
                }
                
                let targetInstances: IndexSet
                switch target {
                case let .point(point):
                    targetInstances = instances(atPoint: point, inObservation: result)
                case let .index(index):
                    targetInstances = IndexSet([index])
                case .all:
                    targetInstances = result.allInstances
                }
                
                for instance in targetInstances {
                    if let mask = try? result.generateScaledMaskForImage(forInstances: IndexSet(integer: instance), from: handler) {
                        process(instance: instance, mask: CIImage(cvPixelBuffer: mask))
                    }
                }
                completion(results)
            }
            try? handler.perform([request])
        }
    } else {
        U2netp.load(contentsOf: URL(fileURLWithPath: modelPath()), completionHandler: { result in
            switch result {
            case let .success(model):
                let modelImageSize = CGSize(width: 320, height: 320)
                if let squareImage = scaleImageToPixelSize(image: image, size: modelImageSize),
                   let pixelBuffer = buffer(from: squareImage),
                   let result = try? model.prediction(in_0: pixelBuffer),
                   let maskImage = UIImage(pixelBuffer: result.out_p1),
                   let scaledMaskImage = scaleImageToPixelSize(image: maskImage, size: image.size),
                   let ciImage = CIImage(image: scaledMaskImage) {
                    process(instance: 0, mask: ciImage)
                }
            case .failure:
                break
            }
            completion(results)
        })
    }
}

/// Returns the instances located at a normalized point of a mask observation.
///
/// - Parameters:
///   - maybePoint: Normalized point to look up; `nil` selects every instance.
///   - observation: Observation whose instance mask is inspected.
/// - Returns: The single instance whose label is stored at the point, or all instances when the
///   point is `nil`, the label at the point is 0, or the mask cannot be read.
@available(iOS 17.0, *)
private func instances(atPoint maybePoint: CGPoint?, inObservation observation: VNInstanceMaskObservation) -> IndexSet {
    guard let point = maybePoint else {
        return observation.allInstances
    }
    
    let instanceMap = observation.instanceMask
    let width = CVPixelBufferGetWidth(instanceMap)
    let height = CVPixelBufferGetHeight(instanceMap)
    guard width > 0, height > 0 else {
        return observation.allInstances
    }
    
    let coords = VNImagePointForNormalizedPoint(point, width - 1, height - 1)
    guard !coords.x.isNaN, !coords.y.isNaN else {
        return observation.allInstances
    }
    // Clamp to the buffer so a point outside the unit square cannot read out of bounds.
    let column = Int(min(max(coords.x, 0.0), CGFloat(width - 1)))
    let row = Int(min(max(coords.y, 0.0), CGFloat(height - 1)))
    
    CVPixelBufferLockBaseAddress(instanceMap, .readOnly)
    defer {
        CVPixelBufferUnlockBaseAddress(instanceMap, .readOnly)
    }
    guard let pixels = CVPixelBufferGetBaseAddress(instanceMap) else {
        return observation.allInstances
    }
    let bytesPerRow = CVPixelBufferGetBytesPerRow(instanceMap)
    // The mask is read as one byte per pixel holding the instance label.
    let instanceLabel = pixels.load(fromByteOffset: row * bytesPerRow + column, as: UInt8.self)
    
    return instanceLabel == 0 ? observation.allInstances : [Int(instanceLabel)]
}

private extension UIImage {
    /// Creates an image from a pixel buffer via VideoToolbox; fails when no `CGImage` is produced.
    convenience init?(pixelBuffer: CVPixelBuffer) {
        var cgImage: CGImage?
        VTCreateCGImageFromCVPixelBuffer(pixelBuffer, options: nil, imageOut: &cgImage)
        guard let cgImage = cgImage else {
            return nil
        }
        self.init(cgImage: cgImage)
    }
}

/// Draws `image` stretched into an opaque bitmap of `size` at scale 1.
///
/// - Returns: The rendered image, or `nil` when the graphics context yields none.
private func scaleImageToPixelSize(image: UIImage, size: CGSize) -> UIImage? {
    UIGraphicsBeginImageContextWithOptions(size, true, 1.0)
    image.draw(in: CGRect(origin: CGPoint(), size: size), blendMode: .copy, alpha: 1.0)
    let result = UIGraphicsGetImageFromCurrentImageContext()
    UIGraphicsEndImageContext()
    return result
}

/// Renders `image` into a newly created 32ARGB pixel buffer of `image.size`.
///
/// - Returns: The filled pixel buffer, or `nil` when the buffer or its bitmap context cannot be created.
private func buffer(from image: UIImage) -> CVPixelBuffer? {
    let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue, kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
    var pixelBuffer: CVPixelBuffer?
    let status = CVPixelBufferCreate(kCFAllocatorDefault, Int(image.size.width), Int(image.size.height), kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
    guard status == kCVReturnSuccess, let pixelBufferUnwrapped = pixelBuffer else {
        return nil
    }
    
    let lockFlags = CVPixelBufferLockFlags(rawValue: 0)
    CVPixelBufferLockBaseAddress(pixelBufferUnwrapped, lockFlags)
    // Unlock on every exit path, including the early return when the context cannot be created.
    defer {
        CVPixelBufferUnlockBaseAddress(pixelBufferUnwrapped, lockFlags)
    }
    
    let pixelData = CVPixelBufferGetBaseAddress(pixelBufferUnwrapped)
    let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
    guard let context = CGContext(data: pixelData, width: Int(image.size.width), height: Int(image.size.height), bitsPerComponent: 8, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBufferUnwrapped), space: rgbColorSpace, bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue) else {
        return nil
    }
    
    // Flip vertically so the UIKit drawing below lands with the expected orientation.
    context.translateBy(x: 0, y: image.size.height)
    context.scaleBy(x: 1.0, y: -1.0)
    
    UIGraphicsPushContext(context)
    image.draw(in: CGRect(x: 0, y: 0, width: image.size.width, height: image.size.height))
    UIGraphicsPopContext()
    
    return pixelBufferUnwrapped
}