diff --git a/apps/api/src/__tests__/snips/crawl.test.ts b/apps/api/src/__tests__/snips/crawl.test.ts
index ceb66631..beb8932a 100644
--- a/apps/api/src/__tests__/snips/crawl.test.ts
+++ b/apps/api/src/__tests__/snips/crawl.test.ts
@@ -1,4 +1,4 @@
-import { crawl } from "./lib";
+import { asyncCrawl, asyncCrawlWaitForFinish, crawl, crawlOngoing } from "./lib";
 import { describe, it, expect } from "@jest/globals";
 
 describe("Crawl tests", () => {
@@ -45,6 +45,23 @@ describe("Crawl tests", () => {
       delay: 5,
     });
   }, 300000);
+
+  it.concurrent("ongoing crawls endpoint works", async () => {
+    const res = await asyncCrawl({
+      url: "https://firecrawl.dev",
+      limit: 3,
+    });
+
+    const ongoing = await crawlOngoing();
+
+    expect(ongoing.ids).toContain(res.id);
+
+    await asyncCrawlWaitForFinish(res.id);
+
+    const ongoing2 = await crawlOngoing();
+
+    expect(ongoing2.ids).not.toContain(res.id);
+  }, 120000);
 
   // TEMP: Flaky
   // it.concurrent("discovers URLs properly when origin is not included", async () => {
diff --git a/apps/api/src/__tests__/snips/lib.ts b/apps/api/src/__tests__/snips/lib.ts
index c8d81838..ea900a5f 100644
--- a/apps/api/src/__tests__/snips/lib.ts
+++ b/apps/api/src/__tests__/snips/lib.ts
@@ -1,7 +1,7 @@
 import { configDotenv } from "dotenv";
 configDotenv();
 
-import { ScrapeRequestInput, Document, ExtractRequestInput, ExtractResponse, CrawlRequestInput, MapRequestInput, BatchScrapeRequestInput, SearchRequestInput, CrawlStatusResponse } from "../../controllers/v1/types";
+import { ScrapeRequestInput, Document, ExtractRequestInput, ExtractResponse, CrawlRequestInput, MapRequestInput, BatchScrapeRequestInput, SearchRequestInput, CrawlStatusResponse, CrawlResponse, OngoingCrawlsResponse, ErrorResponse } from "../../controllers/v1/types";
 import request from "supertest";
 
 // =========================================
@@ -90,6 +90,20 @@ async function crawlStatus(id: string) {
     .send();
 }
 
+async function crawlOngoingRaw() {
+  return await request(TEST_URL)
+    .get("/v1/crawl/ongoing")
+    .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
+    .send();
+}
+
+export async function crawlOngoing(): Promise<Exclude<OngoingCrawlsResponse, ErrorResponse>> {
+  const res = await crawlOngoingRaw();
+  expect(res.statusCode).toBe(200);
+  expect(res.body.success).toBe(true);
+  return res.body;
+}
+
 function expectCrawlStartToSucceed(response: Awaited<ReturnType<typeof crawlStart>>) {
   expect(response.statusCode).toBe(200);
   expect(response.body.success).toBe(true);
@@ -106,6 +120,25 @@ function expectCrawlToSucceed(response: Awaited<ReturnType<typeof crawlStatus>>)
   expect(response.body.data.length).toBeGreaterThan(0);
 }
 
+export async function asyncCrawl(body: CrawlRequestInput): Promise<Exclude<CrawlResponse, ErrorResponse>> {
+  const cs = await crawlStart(body);
+  expectCrawlStartToSucceed(cs);
+  return cs.body;
+}
+
+export async function asyncCrawlWaitForFinish(id: string): Promise<Exclude<CrawlStatusResponse, ErrorResponse>> {
+  let x;
+
+  do {
+    x = await crawlStatus(id);
+    expect(x.statusCode).toBe(200);
+    expect(typeof x.body.status).toBe("string");
+  } while (x.body.status === "scraping");
+
+  expectCrawlToSucceed(x);
+  return x.body;
+}
+
 export async function crawl(body: CrawlRequestInput): Promise<CrawlStatusResponse> {
   const cs = await crawlStart(body);
   expectCrawlStartToSucceed(cs);
diff --git a/apps/api/src/controllers/v1/crawl-ongoing.ts b/apps/api/src/controllers/v1/crawl-ongoing.ts
new file mode 100644
index 00000000..67e08eb1
--- /dev/null
+++ b/apps/api/src/controllers/v1/crawl-ongoing.ts
@@ -0,0 +1,22 @@
+import { Response } from "express";
+import {
+  OngoingCrawlsResponse,
+  RequestWithAuth,
+} from "./types";
+import {
+  getCrawlsByTeamId,
+} from "../../lib/crawl-redis";
+import { configDotenv } from "dotenv";
+configDotenv();
+
+export async function ongoingCrawlsController(
+  req: RequestWithAuth<{}, undefined, OngoingCrawlsResponse>,
+  res: Response<OngoingCrawlsResponse>,
+) {
+  const ids = await getCrawlsByTeamId(req.auth.team_id);
+
+  res.status(200).json({
+    success: true,
+    ids,
+  });
+}
diff --git a/apps/api/src/controllers/v1/types.ts b/apps/api/src/controllers/v1/types.ts
index 5ee1328d..f3713e89 100644
--- a/apps/api/src/controllers/v1/types.ts
+++ b/apps/api/src/controllers/v1/types.ts
@@ -875,6 +875,13 @@ export type CrawlStatusResponse =
     data: Document[];
   };
 
+export type OngoingCrawlsResponse =
+  | ErrorResponse
+  | {
+      success: true;
+      ids: string[];
+    };
+
 export type CrawlErrorsResponse =
   | ErrorResponse
   | {
diff --git a/apps/api/src/lib/crawl-redis.ts b/apps/api/src/lib/crawl-redis.ts
index 0984a628..277b0d73 100644
--- a/apps/api/src/lib/crawl-redis.ts
+++ b/apps/api/src/lib/crawl-redis.ts
@@ -26,6 +26,13 @@ export async function saveCrawl(id: string, crawl: StoredCrawl) {
   });
   await redisEvictConnection.set("crawl:" + id, JSON.stringify(crawl));
   await redisEvictConnection.expire("crawl:" + id, 24 * 60 * 60);
+
+  await redisEvictConnection.sadd("crawls_by_team_id:" + crawl.team_id, id);
+  await redisEvictConnection.expire("crawls_by_team_id:" + crawl.team_id, 24 * 60 * 60);
+}
+
+export async function getCrawlsByTeamId(team_id: string): Promise<string[]> {
+  return await redisEvictConnection.smembers("crawls_by_team_id:" + team_id);
 }
 
 export async function getCrawl(id: string): Promise<StoredCrawl | null> {
@@ -183,6 +190,12 @@ export async function finishCrawl(id: string) {
   });
   await redisEvictConnection.set("crawl:" + id + ":finish", "yes");
   await redisEvictConnection.expire("crawl:" + id + ":finish", 24 * 60 * 60);
+
+  const crawl = await getCrawl(id);
+  if (crawl && crawl.team_id) {
+    await redisEvictConnection.srem("crawls_by_team_id:" + crawl.team_id, id);
+    await redisEvictConnection.expire("crawls_by_team_id:" + crawl.team_id, 24 * 60 * 60);
+  }
 }
 
 export async function getCrawlJobs(id: string): Promise<string[]> {
diff --git a/apps/api/src/routes/v1.ts b/apps/api/src/routes/v1.ts
index 686653bc..373a11b4 100644
--- a/apps/api/src/routes/v1.ts
+++ b/apps/api/src/routes/v1.ts
@@ -35,6 +35,7 @@ import { generateLLMsTextStatusController } from "../controllers/v1/generate-llm
 import { deepResearchController } from "../controllers/v1/deep-research";
 import { deepResearchStatusController } from "../controllers/v1/deep-research-status";
 import { tokenUsageController } from "../controllers/v1/token-usage";
+import { ongoingCrawlsController } from "../controllers/v1/crawl-ongoing";
 
 function checkCreditsMiddleware(
   minimum?: number,
@@ -219,6 +220,12 @@
   wrap(crawlStatusController),
 );
 
+v1Router.get(
+  "/crawl/ongoing",
+  authMiddleware(RateLimiterMode.CrawlStatus),
+  wrap(ongoingCrawlsController),
+);
+
 v1Router.get(
   "/batch/scrape/:jobId",
   authMiddleware(RateLimiterMode.CrawlStatus),
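
For reference, a minimal sketch of how a client might exercise the new endpoint once this lands. The base URL, port, and FIRECRAWL_API_KEY variable are illustrative assumptions, not part of the diff; the response shape follows the success arm of OngoingCrawlsResponse:

// Usage sketch (illustrative, not part of the diff): list this team's ongoing crawls.
// Assumes a self-hosted API at localhost:3002 and a bearer key in FIRECRAWL_API_KEY.
const res = await fetch("http://localhost:3002/v1/crawl/ongoing", {
  headers: { Authorization: `Bearer ${process.env.FIRECRAWL_API_KEY}` },
});
const body = (await res.json()) as { success: true; ids: string[] };
// ids holds the crawl IDs added by saveCrawl and not yet removed by finishCrawl,
// i.e. the members of the team's crawls_by_team_id Redis set.
console.log(body.ids);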